diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 89a35c48cf..3f6ba7def1 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -767,6 +767,63 @@ defmodule EXLA.Defn do {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr, state), cache} end + defp cached_recur_operator( + :block, + %T{data: %Expr{args: [%Nx.Block.RFFT{} = rfft_struct, [tensor], expr, _callback]}}, + state, + cache + ) do + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + opts = [length: rfft_struct.length, axis: rfft_struct.axis] + + opts = + if eps = rfft_struct.eps do + Keyword.put(opts, :eps, eps) + else + opts + end + + # expr.type is complex; input tensor is real + input_type = Nx.Type.to_real(expr.type) + + {fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr, state), + cache} + end + + defp cached_recur_operator( + :block, + %T{data: %Expr{args: [%Nx.Block.IRFFT{} = irfft_struct, [tensor], expr, _callback]}}, + state, + cache + ) do + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + opts = [length: irfft_struct.length, axis: irfft_struct.axis] + + opts = + if eps = irfft_struct.eps do + Keyword.put(opts, :eps, eps) + else + opts + end + + # expr.type is real; input tensor is complex. + # pad_n = div(n,2)+1 (the expected input size), while fft_n = n (the output length). + n = irfft_struct.length + input_type = Nx.Type.to_complex(expr.type) + + {fft( + &Value.fft(&1, :irfft, &2, &3), + input_type, + expr.type, + div(n, 2) + 1, + [tensor, opts], + expr, + state + ), cache} + end + defp cached_recur_operator(:block, %T{data: %Expr{args: args}}, state, cache) do [struct, in_args, expr, _callback] = args %module{} = struct @@ -1233,10 +1290,10 @@ defmodule EXLA.Defn do end defp to_operator(:fft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :fft, &2, &3), args, out, state) + do: fft(&Value.fft(&1, :fft, &2, &3), out.type, out.type, args, out, state) defp to_operator(:ifft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :ifft, &2, &3), args, out, state) + do: fft(&Value.fft(&1, :ifft, &2, &3), out.type, out.type, args, out, state) defp to_operator(:is_nan, [%Value{} = arg], out, _state), do: Value.is_nan(arg, expr_to_typespec(out)) @@ -1561,16 +1618,16 @@ defmodule EXLA.Defn do EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end - defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do - n = opts[:length] + defp fft(exla_op, input_type, output_type, pad_n \\ nil, [%Value{} = tensor, opts], ans, state) do + fft_n = opts[:length] + pad_n = pad_n || fft_n axis = opts[:axis] - output_type = Nx.Type.to_complex(type) - tensor = to_type(tensor, output_type) + tensor = to_type(tensor, input_type) shape = op_shape(tensor) m = elem(shape, axis) - tensor = fft_pad_or_slice(tensor, m, n, axis, shape, output_type, state) + tensor = fft_pad_or_slice(tensor, m, pad_n, axis, shape, input_type, state) last_axis = tuple_size(shape) - 1 @@ -1582,15 +1639,26 @@ defmodule EXLA.Defn do ax -> ax end) - {transposed_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names) - transposed_typespec = Typespec.tensor(ans.type, transposed_shape) + padded_shape = op_shape(tensor) + + {transposed_input_shape, _} = + Nx.Shape.transpose( + padded_shape, + permutation, + List.duplicate(nil, tuple_size(padded_shape)) + ) + + transposed_input_typespec = Typespec.tensor(input_type, transposed_input_shape) + + {transposed_output_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names) + transposed_output_typespec = Typespec.tensor(output_type, transposed_output_shape) tensor - |> Value.transpose(permutation, transposed_typespec) - |> exla_op.([n], transposed_typespec) + |> Value.transpose(permutation, transposed_input_typespec) + |> exla_op.([fft_n], transposed_output_typespec) |> Value.transpose(permutation, expr_to_typespec(ans)) else - exla_op.(tensor, [n], expr_to_typespec(ans)) + exla_op.(tensor, [fft_n], expr_to_typespec(ans)) end end @@ -1655,8 +1723,10 @@ defmodule EXLA.Defn do Value.slice(tensor, starts, limit_indices, strides, typespec) m < n -> + zero_value = if Nx.Type.complex?(output_type), do: Complex.new(0), else: 0 + zero = - Value.constant(state.builder, [Complex.new(0)], Typespec.tensor(output_type, {})) + Value.constant(state.builder, [zero_value], Typespec.tensor(output_type, {})) padding_config = {0, 0, 0} diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 463ed01faf..9d028ff6dd 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -372,7 +372,7 @@ defmodule EXLA.MLIR.Value do end def fft(%Value{function: func} = value, fft_kind, fft_length, typespec) - when fft_kind in [:fft, :ifft] + when fft_kind in [:fft, :ifft, :rfft, :irfft] when is_list(fft_length) or is_integer(fft_length) do result_types = typespecs_to_mlir_types([typespec]) @@ -1070,7 +1070,7 @@ defmodule EXLA.MLIR.Value do defp attr_transpose(value) when value in [:adjoint, :transpose, :no_transpose], do: attr_enum("stablehlo", "transpose", value) - defp attr_fft_type(value) when value in [:fft, :ifft], + defp attr_fft_type(value) when value in [:fft, :ifft, :rfft, :irfft], do: attr_enum("stablehlo", "fft_type", value) defp attr_enum(dialect, enum_name, value) do diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 339bf019d2..82b3e7200d 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -16797,6 +16797,283 @@ defmodule Nx do end) end + @doc """ + Calculates the real-input DFT of the given tensor. + + Exploits the conjugate-symmetry property of the DFT for real inputs by + computing the full FFT and returning only the non-redundant first + `floor(length / 2) + 1` frequency components along the transform axis. + + ## Options + + * `:eps` - Threshold which backends can use for cleaning-up results. Defaults to `1.0e-10`. + * `:length` - Either a positive integer or `:power_of_two`. Will pad or slice the tensor + along the transform axis accordingly. `:power_of_two` will automatically pad to the + next power of two. Defaults to the axis size. + * `:axis` - the axis upon which the real DFT will be calculated. Defaults to the last axis. + + ## Examples + + iex> Nx.rfft(Nx.tensor([1.0, 1.0, 0.0, 0.0])) + #Nx.Tensor< + c64[3] + [2.0+0.0i, 1.0-1.0i, 0.0+0.0i] + > + + iex> Nx.rfft(Nx.tensor([1.0, 1.0, 1.0, 0.0, 1.0, 1.0])) + #Nx.Tensor< + c64[4] + [5.0+0.0i, 1.0+0.0i, -1.0+0.0i, 1.0+0.0i] + > + + The calculation can happen on a specific axis: + + iex> tensor = Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) + iex> Nx.rfft(tensor, axis: -1) + #Nx.Tensor< + c64[2][3] + [ + [2.0+0.0i, 1.0-1.0i, 0.0+0.0i], + [1.0+0.0i, 1.0+0.0i, 1.0+0.0i] + ] + > + iex> Nx.rfft(tensor, axis: -2) + #Nx.Tensor< + c64[2][4] + [ + [2.0+0.0i, 1.0+0.0i, 0.0+0.0i, 0.0+0.0i], + [0.0+0.0i, 1.0+0.0i, 0.0+0.0i, 0.0+0.0i] + ] + > + + Padding and slicing can be introduced through `:length`: + + iex> Nx.rfft(Nx.tensor([1.0, 1.0]), length: 4) + #Nx.Tensor< + c64[3] + [2.0+0.0i, 1.0-1.0i, 0.0+0.0i] + > + + iex> Nx.rfft(Nx.tensor([1.0, 1.0, 0.0]), length: :power_of_two) + #Nx.Tensor< + c64[3] + [2.0+0.0i, 1.0-1.0i, 0.0+0.0i] + > + + ## Vectorized tensors + + Vectorized tensors work the same as N-dimensional tensors + + iex> tensor = Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) |> Nx.vectorize(:x) + iex> Nx.rfft(tensor) + #Nx.Tensor< + vectorized[x: 2] + c64[3] + [ + [2.0+0.0i, 1.0-1.0i, 0.0+0.0i], + [1.0+0.0i, 1.0+0.0i, 1.0+0.0i] + ] + > + + ## Error Cases + + iex> Nx.rfft(Nx.tensor([Complex.new(1, 0), Complex.new(0, 1)])) + ** (ArgumentError) Nx.rfft/2 expects a real tensor, got type: {:c, 64} + + iex> Nx.rfft(Nx.tensor([1.0, 1.0]), length: :invalid) + ** (ArgumentError) expected an integer or :power_of_two as length, got: :invalid + """ + @doc type: :ndim + def rfft(tensor, opts \\ []) do + tensor = to_tensor(tensor) + + if Nx.Type.complex?(tensor.type) do + raise ArgumentError, "Nx.rfft/2 expects a real tensor, got type: #{inspect(tensor.type)}" + end + + apply_vectorized(tensor, fn tensor, offset -> + shape = Nx.Shape.fft(tensor.shape) + opts = Keyword.validate!(opts, [:length, axis: -1, eps: 1.0e-10]) + + axis = Nx.Shape.normalize_axis(shape, opts[:axis], tensor.names, offset) + n = elem(shape, axis) + + length = + case opts[:length] do + :power_of_two -> + 2 ** Kernel.ceil(:math.log2(n)) + + nil -> + n + + n when is_integer(n) and n > 0 -> + n + + length -> + raise ArgumentError, + "expected an integer or :power_of_two as length, got: #{inspect(length)}" + end + + rfft_length = div(length, 2) + 1 + + output_shape = + shape + |> Tuple.insert_at(axis, rfft_length) + |> Tuple.delete_at(axis + 1) + + out = to_template(%{tensor | shape: output_shape, type: Nx.Type.to_complex(tensor.type)}) + block_struct = struct!(Nx.Block.RFFT, eps: opts[:eps], length: length, axis: axis) + + block(block_struct, [tensor], out, fn s, tensor -> + tensor + |> fft(length: s.length, axis: s.axis, eps: s.eps) + |> slice_along_axis(0, div(s.length, 2) + 1, axis: s.axis) + end) + end) + end + + @doc """ + Calculates the Inverse real-input DFT of the given tensor. + + Reconstructs a real-valued signal from a one-sided complex spectrum produced + by `rfft/2`. The input is assumed to contain the non-redundant Hermitian half + of a spectrum of length `n` (i.e. `floor(n / 2) + 1` elements along the + transform axis). The missing conjugate-symmetric components are derived + automatically before calling `ifft/2`. + + ## Options + + * `:eps` - Threshold which backends can use for cleaning-up results. Defaults to `1.0e-10`. + * `:length` - A positive integer specifying the output signal length `n`. Defaults to + `2 * (m - 1)` where `m` is the axis size of the input, which assumes the original + signal had even length. Pass an explicit `:length` to recover odd-length signals. + * `:axis` - the axis upon which the Inverse real DFT will be calculated. Defaults to the + last axis. + + ## Examples + + iex> Nx.irfft(Nx.tensor([2.0, Complex.new(1.0, -1.0), 0.0])) + #Nx.Tensor< + f32[4] + [1.0, 1.0, 0.0, 0.0] + > + + iex> Nx.irfft(Nx.tensor([5.0, 1.0, -1.0, 1.0])) + #Nx.Tensor< + f32[6] + [1.0, 1.0, 1.0, 0.0, 1.0, 1.0] + > + + The calculation can happen on a specific axis: + + iex> tensor = Nx.tensor([[2.0, Complex.new(1.0, -1.0), 0.0], [4.0, 0.0, 0.0]]) + iex> Nx.irfft(tensor, axis: -1) + #Nx.Tensor< + f32[2][4] + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0] + ] + > + + An explicit `:length` recovers odd-length signals and controls input truncation: + + iex> Nx.irfft(Nx.tensor([2.0, Complex.new(1.0, -1.0), 0.0]), length: 4) + #Nx.Tensor< + f32[4] + [1.0, 1.0, 0.0, 0.0] + > + + ## Vectorized tensors + + Vectorized tensors work the same as N-dimensional tensors + + iex> tensor = Nx.tensor([[2.0, Complex.new(1.0, -1.0), 0.0], [4.0, 0.0, 0.0]]) |> Nx.vectorize(:x) + iex> Nx.irfft(tensor) + #Nx.Tensor< + vectorized[x: 2] + f32[4] + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0] + ] + > + + ## Error Cases + + iex> Nx.irfft(Nx.tensor([1.0, 1.0]), length: :invalid) + ** (ArgumentError) expected a positive integer as length, got: :invalid + """ + @doc type: :ndim + def irfft(tensor, opts \\ []) do + apply_vectorized(tensor, fn tensor, offset -> + shape = Nx.Shape.fft(tensor.shape) + opts = Keyword.validate!(opts, [:length, axis: -1, eps: 1.0e-10]) + + axis = Nx.Shape.normalize_axis(shape, opts[:axis], tensor.names, offset) + actual_m = elem(shape, axis) + + n = + case opts[:length] do + nil -> + 2 * (actual_m - 1) + + n when is_integer(n) and n > 0 -> + n + + length -> + raise ArgumentError, + "expected a positive integer as length, got: #{inspect(length)}" + end + + output_shape = + shape + |> Tuple.insert_at(axis, n) + |> Tuple.delete_at(axis + 1) + + out = to_template(%{tensor | shape: output_shape, type: Nx.Type.to_real(tensor.type)}) + block_struct = struct(Nx.Block.IRFFT, eps: opts[:eps], length: n, axis: axis) + + block(block_struct, [tensor], out, fn s, tensor -> + axis = s.axis + n = s.length + m = div(n, 2) + 1 + + actual_m = elem(Nx.shape(tensor), axis) + + tensor = + cond do + actual_m > m -> + slice_along_axis(tensor, 0, m, axis: axis) + + actual_m < m -> + zeros = List.duplicate({0, 0, 0}, tuple_size(Nx.shape(tensor))) + padding_config = List.replace_at(zeros, axis, {0, m - actual_m, 0}) + pad(tensor, 0, padding_config) + + true -> + tensor + end + + # mirror_count = n - m handles both even and odd n: + # even n=8: m=5, mirror indices 1..3 (3 elements), total=8 + # odd n=7: m=4, mirror indices 1..3 (3 elements), total=7 + mirror_count = n - m + + mirror = + tensor + |> slice_along_axis(1, mirror_count, axis: axis) + |> conjugate() + |> reverse(axes: [axis]) + + [tensor, mirror] + |> concatenate(axis: axis) + |> ifft(axis: axis, eps: s.eps) + |> real() + end) + end) + end + @doc """ Creates a tensor of shape `{n}` with linearly spaced samples between `start` and `stop`. diff --git a/nx/lib/nx/block.ex b/nx/lib/nx/block.ex index c47caa1531..944955ca70 100644 --- a/nx/lib/nx/block.ex +++ b/nx/lib/nx/block.ex @@ -73,3 +73,11 @@ end defmodule Nx.Block.IFFT2 do defstruct eps: nil, lengths: nil, axes: nil end + +defmodule Nx.Block.RFFT do + defstruct eps: nil, length: nil, axis: nil +end + +defmodule Nx.Block.IRFFT do + defstruct eps: nil, length: nil, axis: nil +end diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs index 8326e2a9a6..df3efbcb89 100644 --- a/nx/test/nx/defn/grad_test.exs +++ b/nx/test/nx/defn/grad_test.exs @@ -4699,12 +4699,20 @@ defmodule Nx.Defn.GradTest do b = Nx.tensor([2, 3, 4]) assert_all_close( - triangular_solve_grad_wrt_a(a, b, transform_a: :conjugate, left_side: false, lower: false), + triangular_solve_grad_wrt_a(a, b, + transform_a: :conjugate, + left_side: false, + lower: false + ), triangular_solve_grad_wrt_a(a, b, transform_a: :none, left_side: false, lower: false) ) assert_all_close( - triangular_solve_grad_wrt_b(a, b, transform_a: :conjugate, left_side: false, lower: false), + triangular_solve_grad_wrt_b(a, b, + transform_a: :conjugate, + left_side: false, + lower: false + ), triangular_solve_grad_wrt_b(a, b, transform_a: :none, left_side: false, lower: false) ) end diff --git a/torchx/c_src/torchx.cpp b/torchx/c_src/torchx.cpp index 477c8d21d4..ffaab13fb0 100644 --- a/torchx/c_src/torchx.cpp +++ b/torchx/c_src/torchx.cpp @@ -865,6 +865,22 @@ ifft(ErlNifEnv *env, fine::ResourcePtr tensor, int64_t length, REGISTER_TENSOR_NIF(ifft); +fine::Ok> +rfft(ErlNifEnv *env, fine::ResourcePtr tensor, int64_t length, + int64_t axis) { + return tensor_ok(torch::fft::rfft(get_tensor(tensor), length, axis)); +} + +REGISTER_TENSOR_NIF(rfft); + +fine::Ok> +irfft(ErlNifEnv *env, fine::ResourcePtr tensor, int64_t length, + int64_t axis) { + return tensor_ok(torch::fft::irfft(get_tensor(tensor), length, axis)); +} + +REGISTER_TENSOR_NIF(irfft); + fine::Ok> fft2(ErlNifEnv *env, fine::ResourcePtr tensor, std::vector lengths, std::vector axes) { diff --git a/torchx/lib/torchx.ex b/torchx/lib/torchx.ex index e94195f886..506fa958ec 100644 --- a/torchx/lib/torchx.ex +++ b/torchx/lib/torchx.ex @@ -353,6 +353,8 @@ defmodule Torchx do deftensor cbrt(tensor) deftensor fft(tensor, length, axis) deftensor ifft(tensor, length, axis) + deftensor rfft(tensor, length, axis) + deftensor irfft(tensor, length, axis) deftensor fft2(tensor, lengths, axes) deftensor ifft2(tensor, lengths, axes) diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index ce0985cc40..ea9abb93cd 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -94,6 +94,12 @@ defmodule Torchx.Backend do {Nx.Block.TakeAlongAxis, [tensor, indices]} -> take_along_axis_gather(output, tensor, indices, axis: struct.axis) + {Nx.Block.RFFT, [t]} -> + rfft_torchx(output, t, struct.length, struct.axis) + + {Nx.Block.IRFFT, [t]} -> + irfft_torchx(output, t, struct.length, struct.axis) + {Nx.Block.FFT2, [t]} -> fft2_torchx(output, t, struct.lengths, struct.axes) @@ -1092,6 +1098,20 @@ defmodule Torchx.Backend do |> to_nx(out) end + defp rfft_torchx(out, tensor, length, axis) do + tensor + |> from_nx() + |> Torchx.rfft(length, axis) + |> to_nx(out) + end + + defp irfft_torchx(out, tensor, length, axis) do + tensor + |> from_nx() + |> Torchx.irfft(length, axis) + |> to_nx(out) + end + defp fft2_torchx(out, tensor, lengths, axes) do tensor |> from_nx() diff --git a/torchx/test/torchx/nx_block_test.exs b/torchx/test/torchx/nx_block_test.exs index e951e1f699..be307d77b5 100644 --- a/torchx/test/torchx/nx_block_test.exs +++ b/torchx/test/torchx/nx_block_test.exs @@ -2,7 +2,7 @@ defmodule Torchx.NxBlockTest do @moduledoc """ Numerical coverage for `Nx.block/4`-backed APIs on Torchx. - `Nx.fft2/2` and `Nx.ifft2/2` route through `%Nx.Block.FFT2{}` / `%Nx.Block.IFFT2{}`. + `Nx.fft2/2`, `Nx.ifft2/2`, `Nx.rfft/2`, and `Nx.irfft/2` route through block structs. `Torchx.NxDoctestTest` excludes their doctests (BinaryBackend `inspect` strings vs LibTorch signed zeros). Here we assert agreement with a `Nx.BinaryBackend` reference. """ @@ -22,6 +22,87 @@ defmodule Torchx.NxBlockTest do assert Nx.all_close(t_torchx, ref) end + describe "rfft / irfft (block-backed)" do + test "rfft basic" do + same_as_binary( + fn -> Nx.rfft(Nx.tensor([1.0, 1.0, 0.0, 0.0])) end, + fn -> Nx.rfft(Nx.tensor([1.0, 1.0, 0.0, 0.0], backend: Nx.BinaryBackend)) end + ) + end + + test "rfft signed-zero case" do + same_as_binary( + fn -> Nx.rfft(Nx.tensor([1.0, 1.0, 1.0, 0.0, 1.0, 1.0])) end, + fn -> + Nx.rfft(Nx.tensor([1.0, 1.0, 1.0, 0.0, 1.0, 1.0], backend: Nx.BinaryBackend)) + end + ) + end + + test "rfft with axis and length options" do + same_as_binary( + fn -> + tensor = Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) + Nx.rfft(tensor, axis: -2) + end, + fn -> + tensor = + Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], backend: Nx.BinaryBackend) + + Nx.rfft(tensor, axis: -2) + end + ) + end + + test "rfft vectorized" do + same_as_binary( + fn -> + Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) + |> Nx.vectorize(:x) + |> Nx.rfft() + end, + fn -> + Nx.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], backend: Nx.BinaryBackend) + |> Nx.vectorize(:x) + |> Nx.rfft() + end + ) + end + + test "irfft basic" do + same_as_binary( + fn -> Nx.irfft(Nx.tensor([2.0, Complex.new(1.0, -1.0), 0.0])) end, + fn -> + Nx.irfft(Nx.tensor([2.0, Complex.new(1.0, -1.0), 0.0], backend: Nx.BinaryBackend)) + end + ) + end + + test "irfft with length" do + same_as_binary( + fn -> Nx.irfft(Nx.tensor([5.0, 1.0, -1.0, 1.0])) end, + fn -> Nx.irfft(Nx.tensor([5.0, 1.0, -1.0, 1.0], backend: Nx.BinaryBackend)) end + ) + end + + test "irfft vectorized" do + same_as_binary( + fn -> + Nx.tensor([[2.0, Complex.new(1.0, -1.0), 0.0], [4.0, 0.0, 0.0]]) + |> Nx.vectorize(:x) + |> Nx.irfft() + end, + fn -> + Nx.tensor([[2.0, Complex.new(1.0, -1.0), 0.0], [4.0, 0.0, 0.0]], + backend: Nx.BinaryBackend + ) + |> Nx.vectorize(:x) + |> Nx.irfft() + end + ) + end + end + describe "fft2 / ifft2 (block-backed)" do test "fft2 simple matrix" do same_as_binary( diff --git a/torchx/test/torchx/nx_doctest_test.exs b/torchx/test/torchx/nx_doctest_test.exs index ee0ffd8d86..ca6ed2be84 100644 --- a/torchx/test/torchx/nx_doctest_test.exs +++ b/torchx/test/torchx/nx_doctest_test.exs @@ -5,9 +5,9 @@ defmodule Torchx.NxDoctestTest do Many tests are excluded for the reasons below, coverage for the excluded tests can be found on Torchx.NxTest. - `Nx.fft2/2` and `Nx.ifft2/2` are implemented via `Nx.block/4` (`%Nx.Block.FFT2{}` / - `%Nx.Block.IFFT2{}`). Doctest expectations are written for `Nx.BinaryBackend` string - `inspect/1` output; LibTorch can produce signed zeros so complex `inspect` differs. + `Nx.fft2/2`, `Nx.ifft2/2`, `Nx.rfft/2`, and `Nx.irfft/2` are implemented via `Nx.block/4`. + Doctest expectations are written for `Nx.BinaryBackend` string `inspect/1` output; + LibTorch can produce signed zeros so complex `inspect` differs. Those doctests are excluded here; see `Torchx.NxBlockTest` for numerical checks. """ @@ -28,6 +28,8 @@ defmodule Torchx.NxDoctestTest do ifft: 2, fft2: 2, ifft2: 2, + rfft: 2, + irfft: 2, expm1: 1, standard_deviation: 2 ]