From 3a6e6afcae96505cb752f8d821a4097851d12da8 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sun, 19 Apr 2026 02:20:30 -0300 Subject: [PATCH 1/5] feat: add rfft/irfft --- exla/lib/exla/defn.ex | 78 ++++++++-- exla/lib/exla/mlir/value.ex | 4 +- nx/lib/nx.ex | 270 +++++++++++++++++++++++++++++++++++ nx/lib/nx/block.ex | 8 ++ torchx/c_src/torchx.cpp | 16 +++ torchx/lib/torchx.ex | 2 + torchx/lib/torchx/backend.ex | 20 +++ 7 files changed, 383 insertions(+), 15 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 89a35c48cf..881077b926 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -767,6 +767,52 @@ defmodule EXLA.Defn do {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr, state), cache} end + defp cached_recur_operator( + :block, + %T{data: %Expr{args: [%Nx.Block.RFFT{} = rfft_struct, [tensor], expr, _callback]}}, + state, + cache + ) do + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + opts = [length: rfft_struct.length, axis: rfft_struct.axis] + + opts = + if eps = rfft_struct.eps do + Keyword.put(opts, :eps, eps) + else + opts + end + + # expr.type is complex; input tensor is real + input_type = Nx.Type.to_real(expr.type) + {fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr, state), cache} + end + + defp cached_recur_operator( + :block, + %T{data: %Expr{args: [%Nx.Block.IRFFT{} = irfft_struct, [tensor], expr, _callback]}}, + state, + cache + ) do + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + opts = [length: irfft_struct.length, axis: irfft_struct.axis] + + opts = + if eps = irfft_struct.eps do + Keyword.put(opts, :eps, eps) + else + opts + end + + # expr.type is real; input tensor is complex. + # pad_n = div(n,2)+1 (the expected input size), while fft_n = n (the output length). + n = irfft_struct.length + input_type = Nx.Type.to_complex(expr.type) + {fft(&Value.fft(&1, :irfft, &2, &3), input_type, expr.type, div(n, 2) + 1, [tensor, opts], expr, state), cache} + end + defp cached_recur_operator(:block, %T{data: %Expr{args: args}}, state, cache) do [struct, in_args, expr, _callback] = args %module{} = struct @@ -1233,10 +1279,10 @@ defmodule EXLA.Defn do end defp to_operator(:fft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :fft, &2, &3), args, out, state) + do: fft(&Value.fft(&1, :fft, &2, &3), out.type, out.type, args, out, state) defp to_operator(:ifft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :ifft, &2, &3), args, out, state) + do: fft(&Value.fft(&1, :ifft, &2, &3), out.type, out.type, args, out, state) defp to_operator(:is_nan, [%Value{} = arg], out, _state), do: Value.is_nan(arg, expr_to_typespec(out)) @@ -1561,16 +1607,16 @@ defmodule EXLA.Defn do EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end - defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do - n = opts[:length] + defp fft(exla_op, input_type, output_type, pad_n \\ nil, [%Value{} = tensor, opts], ans, state) do + fft_n = opts[:length] + pad_n = pad_n || fft_n axis = opts[:axis] - output_type = Nx.Type.to_complex(type) - tensor = to_type(tensor, output_type) + tensor = to_type(tensor, input_type) shape = op_shape(tensor) m = elem(shape, axis) - tensor = fft_pad_or_slice(tensor, m, n, axis, shape, output_type, state) + tensor = fft_pad_or_slice(tensor, m, pad_n, axis, shape, input_type, state) last_axis = tuple_size(shape) - 1 @@ -1582,15 +1628,19 @@ defmodule EXLA.Defn do ax -> ax end) - {transposed_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names) - transposed_typespec = Typespec.tensor(ans.type, transposed_shape) + padded_shape = op_shape(tensor) + {transposed_input_shape, _} = Nx.Shape.transpose(padded_shape, permutation, List.duplicate(nil, tuple_size(padded_shape))) + transposed_input_typespec = Typespec.tensor(input_type, transposed_input_shape) + + {transposed_output_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names) + transposed_output_typespec = Typespec.tensor(output_type, transposed_output_shape) tensor - |> Value.transpose(permutation, transposed_typespec) - |> exla_op.([n], transposed_typespec) + |> Value.transpose(permutation, transposed_input_typespec) + |> exla_op.([fft_n], transposed_output_typespec) |> Value.transpose(permutation, expr_to_typespec(ans)) else - exla_op.(tensor, [n], expr_to_typespec(ans)) + exla_op.(tensor, [fft_n], expr_to_typespec(ans)) end end @@ -1655,8 +1705,10 @@ defmodule EXLA.Defn do Value.slice(tensor, starts, limit_indices, strides, typespec) m < n -> + zero_value = if Nx.Type.complex?(output_type), do: Complex.new(0), else: 0 + zero = - Value.constant(state.builder, [Complex.new(0)], Typespec.tensor(output_type, {})) + Value.constant(state.builder, [zero_value], Typespec.tensor(output_type, {})) padding_config = {0, 0, 0} diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 463ed01faf..9d028ff6dd 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -372,7 +372,7 @@ defmodule EXLA.MLIR.Value do end def fft(%Value{function: func} = value, fft_kind, fft_length, typespec) - when fft_kind in [:fft, :ifft] + when fft_kind in [:fft, :ifft, :rfft, :irfft] when is_list(fft_length) or is_integer(fft_length) do result_types = typespecs_to_mlir_types([typespec]) @@ -1070,7 +1070,7 @@ defmodule EXLA.MLIR.Value do defp attr_transpose(value) when value in [:adjoint, :transpose, :no_transpose], do: attr_enum("stablehlo", "transpose", value) - defp attr_fft_type(value) when value in [:fft, :ifft], + defp attr_fft_type(value) when value in [:fft, :ifft, :rfft, :irfft], do: attr_enum("stablehlo", "fft_type", value) defp attr_enum(dialect, enum_name, value) do diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 339bf019d2..35c631c16a 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -16797,6 +16797,276 @@ defmodule Nx do end) end + @doc """ + Calculates the real-input DFT of the given tensor. + + Exploits the conjugate-symmetry property of the DFT for real inputs by + computing the full FFT and returning only the non-redundant first + `floor(length / 2) + 1` frequency components along the transform axis. + + ## Options + + * `:eps` - Threshold which backends can use for cleaning-up results. Defaults to `1.0e-10`. + * `:length` - Either a positive integer or `:power_of_two`. Will pad or slice the tensor + along the transform axis accordingly. `:power_of_two` will automatically pad to the + next power of two. Defaults to the axis size. + * `:axis` - the axis upon which the real DFT will be calculated. Defaults to the last axis. + + ## Examples + + iex> Nx.rfft(Nx.tensor([1.0, 1.0, 0.0, 0.0])) + #Nx.Tensor< + c64[3] + [2.0+0.0i, 1.0-1.0i, 0.0+0.0i] + > + + iex> Nx.rfft(Nx.tensor([1.0, 1.0, 1.0, 0.0, 1.0, 1.0])) + #Nx.Tensor< + c64[4] + [5.0+0.0i, 1.0+0.0i, -1.0+0.0i, 1.0+0.0i] + > + + The calculation can happen on a specific axis: + + iex> tensor = Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) + iex> Nx.rfft(tensor, axis: -1) + #Nx.Tensor< + c64[2][3] + [ + [2.0+0.0i, 1.0-1.0i, 0.0+0.0i], + [1.0+0.0i, 1.0+0.0i, 1.0+0.0i] + ] + > + iex> Nx.rfft(tensor, axis: -2) + #Nx.Tensor< + c64[2][4] + [ + [2.0+0.0i, 1.0+0.0i, 0.0+0.0i, 0.0+0.0i], + [0.0+0.0i, 1.0+0.0i, 0.0+0.0i, 0.0+0.0i] + ] + > + + Padding and slicing can be introduced through `:length`: + + iex> Nx.rfft(Nx.tensor([1.0, 1.0]), length: 4) + #Nx.Tensor< + c64[3] + [2.0+0.0i, 1.0-1.0i, 0.0+0.0i] + > + + iex> Nx.rfft(Nx.tensor([1.0, 1.0, 0.0]), length: :power_of_two) + #Nx.Tensor< + c64[3] + [2.0+0.0i, 1.0-1.0i, 0.0+0.0i] + > + + ## Vectorized tensors + + Vectorized tensors work the same as N-dimensional tensors + + iex> tensor = Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) |> Nx.vectorize(:x) + iex> Nx.rfft(tensor) + #Nx.Tensor< + vectorized[x: 2] + c64[3] + [ + [2.0+0.0i, 1.0-1.0i, 0.0+0.0i], + [1.0+0.0i, 1.0+0.0i, 1.0+0.0i] + ] + > + + ## Error Cases + + iex> Nx.rfft(Nx.tensor([Complex.new(1, 0), Complex.new(0, 1)])) + ** (ArgumentError) Nx.rfft/2 expects a real tensor, got type: {:c, 64} + + iex> Nx.rfft(Nx.tensor([1.0, 1.0]), length: :invalid) + ** (ArgumentError) expected an integer or :power_of_two as length, got: :invalid + """ + @doc type: :ndim + def rfft(tensor, opts \\ []) do + tensor = to_tensor(tensor) + + if Nx.Type.complex?(tensor.type) do + raise ArgumentError, "Nx.rfft/2 expects a real tensor, got type: #{inspect(tensor.type)}" + end + + apply_vectorized(tensor, fn tensor, offset -> + shape = Nx.Shape.fft(tensor.shape) + opts = Keyword.validate!(opts, [:length, axis: -1, eps: 1.0e-10]) + + axis = Nx.Shape.normalize_axis(shape, opts[:axis], tensor.names, offset) + n = elem(shape, axis) + + length = + case opts[:length] do + :power_of_two -> + 2 ** Kernel.ceil(:math.log2(n)) + + nil -> + n + + n when is_integer(n) and n > 0 -> + n + + length -> + raise ArgumentError, + "expected an integer or :power_of_two as length, got: #{inspect(length)}" + end + + rfft_length = div(length, 2) + 1 + + output_shape = + shape + |> Tuple.insert_at(axis, rfft_length) + |> Tuple.delete_at(axis + 1) + + out = to_template(%{tensor | shape: output_shape, type: Nx.Type.to_complex(tensor.type)}) + block_struct = struct!(Nx.Block.RFFT, eps: opts[:eps], length: length, axis: axis) + + block(block_struct, [tensor], out, fn s, tensor -> + tensor + |> fft(length: s.length, axis: s.axis, eps: s.eps) + |> slice_along_axis(0, div(s.length, 2) + 1, axis: s.axis) + end) + end) + end + + @doc """ + Calculates the Inverse real-input DFT of the given tensor. + + Reconstructs a real-valued signal from a one-sided complex spectrum produced + by `rfft/2`. The input is assumed to contain the non-redundant Hermitian half + of a spectrum of length `n` (i.e. `floor(n / 2) + 1` elements along the + transform axis). The missing conjugate-symmetric components are derived + automatically before calling `ifft/2`. + + ## Options + + * `:eps` - Threshold which backends can use for cleaning-up results. Defaults to `1.0e-10`. + * `:length` - A positive integer specifying the output signal length `n`. Defaults to + `2 * (m - 1)` where `m` is the axis size of the input, which assumes the original + signal had even length. Pass an explicit `:length` to recover odd-length signals. + * `:axis` - the axis upon which the Inverse real DFT will be calculated. Defaults to the + last axis. + + ## Examples + + iex> Nx.irfft(Nx.tensor([2.0, Complex.new(1.0, -1.0), 0.0])) + #Nx.Tensor< + f32[4] + [1.0, 1.0, 0.0, 0.0] + > + + iex> Nx.irfft(Nx.tensor([5.0, 1.0, -1.0, 1.0])) + #Nx.Tensor< + f32[6] + [1.0, 1.0, 1.0, 0.0, 1.0, 1.0] + > + + The calculation can happen on a specific axis: + + iex> tensor = Nx.tensor([[2.0, Complex.new(1.0, -1.0), 0.0], [4.0, 0.0, 0.0]]) + iex> Nx.irfft(tensor, axis: -1) + #Nx.Tensor< + f32[2][4] + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0] + ] + > + + An explicit `:length` recovers odd-length signals and controls input truncation: + + iex> Nx.irfft(Nx.tensor([2.0, Complex.new(1.0, -1.0), 0.0]), length: 4) + #Nx.Tensor< + f32[4] + [1.0, 1.0, 0.0, 0.0] + > + + ## Vectorized tensors + + Vectorized tensors work the same as N-dimensional tensors + + iex> tensor = Nx.tensor([[2.0, Complex.new(1.0, -1.0), 0.0], [4.0, 0.0, 0.0]]) |> Nx.vectorize(:x) + iex> Nx.irfft(tensor) + #Nx.Tensor< + vectorized[x: 2] + f32[4] + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0] + ] + > + + ## Error Cases + + iex> Nx.irfft(Nx.tensor([1.0, 1.0]), length: :invalid) + ** (ArgumentError) expected a positive integer as length, got: :invalid + """ + @doc type: :ndim + def irfft(tensor, opts \\ []) do + apply_vectorized(tensor, fn tensor, offset -> + shape = Nx.Shape.fft(tensor.shape) + opts = Keyword.validate!(opts, [:length, axis: -1, eps: 1.0e-10]) + + axis = Nx.Shape.normalize_axis(shape, opts[:axis], tensor.names, offset) + actual_m = elem(shape, axis) + + n = + case opts[:length] do + nil -> + 2 * (actual_m - 1) + + n when is_integer(n) and n > 0 -> + n + + length -> + raise ArgumentError, + "expected a positive integer as length, got: #{inspect(length)}" + end + + output_shape = + shape + |> Tuple.insert_at(axis, n) + |> Tuple.delete_at(axis + 1) + + out = to_template(%{tensor | shape: output_shape, type: Nx.Type.to_real(tensor.type)}) + block_struct = struct(Nx.Block.IRFFT, eps: opts[:eps], length: n, axis: axis) + + block(block_struct, [tensor], out, fn s, tensor -> + axis = s.axis + n = s.length + m = div(n, 2) + 1 + + actual_m = elem(Nx.shape(tensor), axis) + + tensor = + cond do + actual_m > m -> slice_along_axis(tensor, 0, m, axis: axis) + actual_m < m -> pad(tensor, 0, List.replace_at(List.duplicate({0, 0, 0}, tuple_size(Nx.shape(tensor))), axis, {0, m - actual_m, 0})) + true -> tensor + end + + # mirror_count = n - m handles both even and odd n: + # even n=8: m=5, mirror indices 1..3 (3 elements), total=8 + # odd n=7: m=4, mirror indices 1..3 (3 elements), total=7 + mirror_count = n - m + + mirror = + tensor + |> slice_along_axis(1, mirror_count, axis: axis) + |> conjugate() + |> reverse(axes: [axis]) + + [tensor, mirror] + |> concatenate(axis: axis) + |> ifft(axis: axis, eps: s.eps) + |> real() + end) + end) + end + @doc """ Creates a tensor of shape `{n}` with linearly spaced samples between `start` and `stop`. diff --git a/nx/lib/nx/block.ex b/nx/lib/nx/block.ex index c47caa1531..944955ca70 100644 --- a/nx/lib/nx/block.ex +++ b/nx/lib/nx/block.ex @@ -73,3 +73,11 @@ end defmodule Nx.Block.IFFT2 do defstruct eps: nil, lengths: nil, axes: nil end + +defmodule Nx.Block.RFFT do + defstruct eps: nil, length: nil, axis: nil +end + +defmodule Nx.Block.IRFFT do + defstruct eps: nil, length: nil, axis: nil +end diff --git a/torchx/c_src/torchx.cpp b/torchx/c_src/torchx.cpp index 477c8d21d4..ffaab13fb0 100644 --- a/torchx/c_src/torchx.cpp +++ b/torchx/c_src/torchx.cpp @@ -865,6 +865,22 @@ ifft(ErlNifEnv *env, fine::ResourcePtr tensor, int64_t length, REGISTER_TENSOR_NIF(ifft); +fine::Ok> +rfft(ErlNifEnv *env, fine::ResourcePtr tensor, int64_t length, + int64_t axis) { + return tensor_ok(torch::fft::rfft(get_tensor(tensor), length, axis)); +} + +REGISTER_TENSOR_NIF(rfft); + +fine::Ok> +irfft(ErlNifEnv *env, fine::ResourcePtr tensor, int64_t length, + int64_t axis) { + return tensor_ok(torch::fft::irfft(get_tensor(tensor), length, axis)); +} + +REGISTER_TENSOR_NIF(irfft); + fine::Ok> fft2(ErlNifEnv *env, fine::ResourcePtr tensor, std::vector lengths, std::vector axes) { diff --git a/torchx/lib/torchx.ex b/torchx/lib/torchx.ex index e94195f886..506fa958ec 100644 --- a/torchx/lib/torchx.ex +++ b/torchx/lib/torchx.ex @@ -353,6 +353,8 @@ defmodule Torchx do deftensor cbrt(tensor) deftensor fft(tensor, length, axis) deftensor ifft(tensor, length, axis) + deftensor rfft(tensor, length, axis) + deftensor irfft(tensor, length, axis) deftensor fft2(tensor, lengths, axes) deftensor ifft2(tensor, lengths, axes) diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index ce0985cc40..ea9abb93cd 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -94,6 +94,12 @@ defmodule Torchx.Backend do {Nx.Block.TakeAlongAxis, [tensor, indices]} -> take_along_axis_gather(output, tensor, indices, axis: struct.axis) + {Nx.Block.RFFT, [t]} -> + rfft_torchx(output, t, struct.length, struct.axis) + + {Nx.Block.IRFFT, [t]} -> + irfft_torchx(output, t, struct.length, struct.axis) + {Nx.Block.FFT2, [t]} -> fft2_torchx(output, t, struct.lengths, struct.axes) @@ -1092,6 +1098,20 @@ defmodule Torchx.Backend do |> to_nx(out) end + defp rfft_torchx(out, tensor, length, axis) do + tensor + |> from_nx() + |> Torchx.rfft(length, axis) + |> to_nx(out) + end + + defp irfft_torchx(out, tensor, length, axis) do + tensor + |> from_nx() + |> Torchx.irfft(length, axis) + |> to_nx(out) + end + defp fft2_torchx(out, tensor, lengths, axes) do tensor |> from_nx() From 885e004182906ba1e1f1a0d8543eaf66c74508c0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sun, 19 Apr 2026 02:23:12 -0300 Subject: [PATCH 2/5] chore: format --- exla/lib/exla/defn.ex | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 881077b926..3f6ba7def1 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -786,7 +786,9 @@ defmodule EXLA.Defn do # expr.type is complex; input tensor is real input_type = Nx.Type.to_real(expr.type) - {fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr, state), cache} + + {fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr, state), + cache} end defp cached_recur_operator( @@ -810,7 +812,16 @@ defmodule EXLA.Defn do # pad_n = div(n,2)+1 (the expected input size), while fft_n = n (the output length). n = irfft_struct.length input_type = Nx.Type.to_complex(expr.type) - {fft(&Value.fft(&1, :irfft, &2, &3), input_type, expr.type, div(n, 2) + 1, [tensor, opts], expr, state), cache} + + {fft( + &Value.fft(&1, :irfft, &2, &3), + input_type, + expr.type, + div(n, 2) + 1, + [tensor, opts], + expr, + state + ), cache} end defp cached_recur_operator(:block, %T{data: %Expr{args: args}}, state, cache) do @@ -1629,7 +1640,14 @@ defmodule EXLA.Defn do end) padded_shape = op_shape(tensor) - {transposed_input_shape, _} = Nx.Shape.transpose(padded_shape, permutation, List.duplicate(nil, tuple_size(padded_shape))) + + {transposed_input_shape, _} = + Nx.Shape.transpose( + padded_shape, + permutation, + List.duplicate(nil, tuple_size(padded_shape)) + ) + transposed_input_typespec = Typespec.tensor(input_type, transposed_input_shape) {transposed_output_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names) From 0034457aff3b3fd21d5350caba3e4de0821b69d2 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:28:36 -0300 Subject: [PATCH 3/5] chore: make ci happy --- nx/lib/nx.ex | 13 +++- nx/lib/nx/binary_backend.ex | 9 +-- nx/test/nx/defn/composite_test.exs | 3 +- nx/test/nx/defn/grad_test.exs | 12 +++- torchx/test/torchx/nx_block_test.exs | 83 +++++++++++++++++++++++++- torchx/test/torchx/nx_doctest_test.exs | 8 ++- 6 files changed, 111 insertions(+), 17 deletions(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 35c631c16a..82b3e7200d 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -17043,9 +17043,16 @@ defmodule Nx do tensor = cond do - actual_m > m -> slice_along_axis(tensor, 0, m, axis: axis) - actual_m < m -> pad(tensor, 0, List.replace_at(List.duplicate({0, 0, 0}, tuple_size(Nx.shape(tensor))), axis, {0, m - actual_m, 0})) - true -> tensor + actual_m > m -> + slice_along_axis(tensor, 0, m, axis: axis) + + actual_m < m -> + zeros = List.duplicate({0, 0, 0}, tuple_size(Nx.shape(tensor))) + padding_config = List.replace_at(zeros, axis, {0, m - actual_m, 0}) + pad(tensor, 0, padding_config) + + true -> + tensor end # mirror_count = n - m handles both even and odd n: diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 84e05c6fbb..5c98fb0ab4 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -527,13 +527,11 @@ defmodule Nx.BinaryBackend do right_batch_item_bits = right_batch_item_length * right_size <<_::bitstring-size(^left_offset_bits), - left_batch_item_binary::bitstring-size(^left_batch_item_bits), - _::bitstring>> = + left_batch_item_binary::bitstring-size(^left_batch_item_bits), _::bitstring>> = left_binary <<_::bitstring-size(^right_offset_bits), - right_batch_item_binary::bitstring-size(^right_batch_item_bits), - _::bitstring>> = + right_batch_item_binary::bitstring-size(^right_batch_item_bits), _::bitstring>> = right_binary bin_dot( @@ -1781,8 +1779,7 @@ defmodule Nx.BinaryBackend do before_slice_size = current - previous <> = + current_bitstring::bitstring-size(^target_chunk), to_traverse::bitstring>> = to_traverse updated_elements = diff --git a/nx/test/nx/defn/composite_test.exs b/nx/test/nx/defn/composite_test.exs index 6dc8fc1d52..2618e29170 100644 --- a/nx/test/nx/defn/composite_test.exs +++ b/nx/test/nx/defn/composite_test.exs @@ -24,8 +24,7 @@ defmodule Nx.Defn.CompositeTest do Nx.tensor(1), Nx.tensor(3, type: {:c, 64}), Nx.tensor(4, type: {:c, 64}) - }, - Nx.tensor(2, type: {:c, 64})} == + }, Nx.tensor(2, type: {:c, 64})} == Composite.traverse( {1, Complex.new(2), Nx.tensor(3)}, 0, diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs index 8326e2a9a6..df3efbcb89 100644 --- a/nx/test/nx/defn/grad_test.exs +++ b/nx/test/nx/defn/grad_test.exs @@ -4699,12 +4699,20 @@ defmodule Nx.Defn.GradTest do b = Nx.tensor([2, 3, 4]) assert_all_close( - triangular_solve_grad_wrt_a(a, b, transform_a: :conjugate, left_side: false, lower: false), + triangular_solve_grad_wrt_a(a, b, + transform_a: :conjugate, + left_side: false, + lower: false + ), triangular_solve_grad_wrt_a(a, b, transform_a: :none, left_side: false, lower: false) ) assert_all_close( - triangular_solve_grad_wrt_b(a, b, transform_a: :conjugate, left_side: false, lower: false), + triangular_solve_grad_wrt_b(a, b, + transform_a: :conjugate, + left_side: false, + lower: false + ), triangular_solve_grad_wrt_b(a, b, transform_a: :none, left_side: false, lower: false) ) end diff --git a/torchx/test/torchx/nx_block_test.exs b/torchx/test/torchx/nx_block_test.exs index e951e1f699..be307d77b5 100644 --- a/torchx/test/torchx/nx_block_test.exs +++ b/torchx/test/torchx/nx_block_test.exs @@ -2,7 +2,7 @@ defmodule Torchx.NxBlockTest do @moduledoc """ Numerical coverage for `Nx.block/4`-backed APIs on Torchx. - `Nx.fft2/2` and `Nx.ifft2/2` route through `%Nx.Block.FFT2{}` / `%Nx.Block.IFFT2{}`. + `Nx.fft2/2`, `Nx.ifft2/2`, `Nx.rfft/2`, and `Nx.irfft/2` route through block structs. `Torchx.NxDoctestTest` excludes their doctests (BinaryBackend `inspect` strings vs LibTorch signed zeros). Here we assert agreement with a `Nx.BinaryBackend` reference. """ @@ -22,6 +22,87 @@ defmodule Torchx.NxBlockTest do assert Nx.all_close(t_torchx, ref) end + describe "rfft / irfft (block-backed)" do + test "rfft basic" do + same_as_binary( + fn -> Nx.rfft(Nx.tensor([1.0, 1.0, 0.0, 0.0])) end, + fn -> Nx.rfft(Nx.tensor([1.0, 1.0, 0.0, 0.0], backend: Nx.BinaryBackend)) end + ) + end + + test "rfft signed-zero case" do + same_as_binary( + fn -> Nx.rfft(Nx.tensor([1.0, 1.0, 1.0, 0.0, 1.0, 1.0])) end, + fn -> + Nx.rfft(Nx.tensor([1.0, 1.0, 1.0, 0.0, 1.0, 1.0], backend: Nx.BinaryBackend)) + end + ) + end + + test "rfft with axis and length options" do + same_as_binary( + fn -> + tensor = Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) + Nx.rfft(tensor, axis: -2) + end, + fn -> + tensor = + Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], backend: Nx.BinaryBackend) + + Nx.rfft(tensor, axis: -2) + end + ) + end + + test "rfft vectorized" do + same_as_binary( + fn -> + Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) + |> Nx.vectorize(:x) + |> Nx.rfft() + end, + fn -> + Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], backend: Nx.BinaryBackend) + |> Nx.vectorize(:x) + |> Nx.rfft() + end + ) + end + + test "irfft basic" do + same_as_binary( + fn -> Nx.irfft(Nx.tensor([2.0, Complex.new(1.0, -1.0), 0.0])) end, + fn -> + Nx.irfft(Nx.tensor([2.0, Complex.new(1.0, -1.0), 0.0], backend: Nx.BinaryBackend)) + end + ) + end + + test "irfft with length" do + same_as_binary( + fn -> Nx.irfft(Nx.tensor([5.0, 1.0, -1.0, 1.0])) end, + fn -> Nx.irfft(Nx.tensor([5.0, 1.0, -1.0, 1.0], backend: Nx.BinaryBackend)) end + ) + end + + test "irfft vectorized" do + same_as_binary( + fn -> + Nx.tensor([[2.0, Complex.new(1.0, -1.0), 0.0], [4.0, 0.0, 0.0]]) + |> Nx.vectorize(:x) + |> Nx.irfft() + end, + fn -> + Nx.tensor([[2.0, Complex.new(1.0, -1.0), 0.0], [4.0, 0.0, 0.0]], + backend: Nx.BinaryBackend + ) + |> Nx.vectorize(:x) + |> Nx.irfft() + end + ) + end + end + describe "fft2 / ifft2 (block-backed)" do test "fft2 simple matrix" do same_as_binary( diff --git a/torchx/test/torchx/nx_doctest_test.exs b/torchx/test/torchx/nx_doctest_test.exs index ee0ffd8d86..ca6ed2be84 100644 --- a/torchx/test/torchx/nx_doctest_test.exs +++ b/torchx/test/torchx/nx_doctest_test.exs @@ -5,9 +5,9 @@ defmodule Torchx.NxDoctestTest do Many tests are excluded for the reasons below, coverage for the excluded tests can be found on Torchx.NxTest. - `Nx.fft2/2` and `Nx.ifft2/2` are implemented via `Nx.block/4` (`%Nx.Block.FFT2{}` / - `%Nx.Block.IFFT2{}`). Doctest expectations are written for `Nx.BinaryBackend` string - `inspect/1` output; LibTorch can produce signed zeros so complex `inspect` differs. + `Nx.fft2/2`, `Nx.ifft2/2`, `Nx.rfft/2`, and `Nx.irfft/2` are implemented via `Nx.block/4`. + Doctest expectations are written for `Nx.BinaryBackend` string `inspect/1` output; + LibTorch can produce signed zeros so complex `inspect` differs. Those doctests are excluded here; see `Torchx.NxBlockTest` for numerical checks. """ @@ -28,6 +28,8 @@ defmodule Torchx.NxDoctestTest do ifft: 2, fft2: 2, ifft2: 2, + rfft: 2, + irfft: 2, expm1: 1, standard_deviation: 2 ] From 5ef6ee5dc1a6513eb9b3a24800a24ae598342fc0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:31:01 -0300 Subject: [PATCH 4/5] chore: format --- nx/lib/nx/binary_backend.ex | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 5c98fb0ab4..84e05c6fbb 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -527,11 +527,13 @@ defmodule Nx.BinaryBackend do right_batch_item_bits = right_batch_item_length * right_size <<_::bitstring-size(^left_offset_bits), - left_batch_item_binary::bitstring-size(^left_batch_item_bits), _::bitstring>> = + left_batch_item_binary::bitstring-size(^left_batch_item_bits), + _::bitstring>> = left_binary <<_::bitstring-size(^right_offset_bits), - right_batch_item_binary::bitstring-size(^right_batch_item_bits), _::bitstring>> = + right_batch_item_binary::bitstring-size(^right_batch_item_bits), + _::bitstring>> = right_binary bin_dot( @@ -1779,7 +1781,8 @@ defmodule Nx.BinaryBackend do before_slice_size = current - previous <> = + current_bitstring::bitstring-size(^target_chunk), + to_traverse::bitstring>> = to_traverse updated_elements = From 82ee0c31c5f1e7575db7463f437364a5fcea71a6 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 20 Apr 2026 14:32:36 -0300 Subject: [PATCH 5/5] chore: format --- nx/test/nx/defn/composite_test.exs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nx/test/nx/defn/composite_test.exs b/nx/test/nx/defn/composite_test.exs index 2618e29170..6dc8fc1d52 100644 --- a/nx/test/nx/defn/composite_test.exs +++ b/nx/test/nx/defn/composite_test.exs @@ -24,7 +24,8 @@ defmodule Nx.Defn.CompositeTest do Nx.tensor(1), Nx.tensor(3, type: {:c, 64}), Nx.tensor(4, type: {:c, 64}) - }, Nx.tensor(2, type: {:c, 64})} == + }, + Nx.tensor(2, type: {:c, 64})} == Composite.traverse( {1, Complex.new(2), Nx.tensor(3)}, 0,