Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 83 additions & 13 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,63 @@ defmodule EXLA.Defn do
{fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr, state), cache}
end

defp cached_recur_operator(
:block,
%T{data: %Expr{args: [%Nx.Block.RFFT{} = rfft_struct, [tensor], expr, _callback]}},
state,
cache
) do
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()

opts = [length: rfft_struct.length, axis: rfft_struct.axis]

opts =
if eps = rfft_struct.eps do
Keyword.put(opts, :eps, eps)
else
opts
end

# expr.type is complex; input tensor is real
input_type = Nx.Type.to_real(expr.type)

{fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr, state),
cache}
end

defp cached_recur_operator(
:block,
%T{data: %Expr{args: [%Nx.Block.IRFFT{} = irfft_struct, [tensor], expr, _callback]}},
state,
cache
) do
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()

opts = [length: irfft_struct.length, axis: irfft_struct.axis]

opts =
if eps = irfft_struct.eps do
Keyword.put(opts, :eps, eps)
else
opts
end

# expr.type is real; input tensor is complex.
# pad_n = div(n,2)+1 (the expected input size), while fft_n = n (the output length).
n = irfft_struct.length
input_type = Nx.Type.to_complex(expr.type)

{fft(
&Value.fft(&1, :irfft, &2, &3),
input_type,
expr.type,
div(n, 2) + 1,
[tensor, opts],
expr,
state
), cache}
end

defp cached_recur_operator(:block, %T{data: %Expr{args: args}}, state, cache) do
[struct, in_args, expr, _callback] = args
%module{} = struct
Expand Down Expand Up @@ -1233,10 +1290,10 @@ defmodule EXLA.Defn do
end

defp to_operator(:fft, [%Value{} | _] = args, out, state),
do: fft(&Value.fft(&1, :fft, &2, &3), args, out, state)
do: fft(&Value.fft(&1, :fft, &2, &3), out.type, out.type, args, out, state)

defp to_operator(:ifft, [%Value{} | _] = args, out, state),
do: fft(&Value.fft(&1, :ifft, &2, &3), args, out, state)
do: fft(&Value.fft(&1, :ifft, &2, &3), out.type, out.type, args, out, state)

defp to_operator(:is_nan, [%Value{} = arg], out, _state),
do: Value.is_nan(arg, expr_to_typespec(out))
Expand Down Expand Up @@ -1561,16 +1618,16 @@ defmodule EXLA.Defn do
EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type)
end

defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do
n = opts[:length]
defp fft(exla_op, input_type, output_type, pad_n \\ nil, [%Value{} = tensor, opts], ans, state) do
fft_n = opts[:length]
pad_n = pad_n || fft_n
axis = opts[:axis]
output_type = Nx.Type.to_complex(type)
tensor = to_type(tensor, output_type)
tensor = to_type(tensor, input_type)

shape = op_shape(tensor)
m = elem(shape, axis)

tensor = fft_pad_or_slice(tensor, m, n, axis, shape, output_type, state)
tensor = fft_pad_or_slice(tensor, m, pad_n, axis, shape, input_type, state)

last_axis = tuple_size(shape) - 1

Expand All @@ -1582,15 +1639,26 @@ defmodule EXLA.Defn do
ax -> ax
end)

{transposed_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names)
transposed_typespec = Typespec.tensor(ans.type, transposed_shape)
padded_shape = op_shape(tensor)

{transposed_input_shape, _} =
Nx.Shape.transpose(
padded_shape,
permutation,
List.duplicate(nil, tuple_size(padded_shape))
)

transposed_input_typespec = Typespec.tensor(input_type, transposed_input_shape)

{transposed_output_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names)
transposed_output_typespec = Typespec.tensor(output_type, transposed_output_shape)

tensor
|> Value.transpose(permutation, transposed_typespec)
|> exla_op.([n], transposed_typespec)
|> Value.transpose(permutation, transposed_input_typespec)
|> exla_op.([fft_n], transposed_output_typespec)
|> Value.transpose(permutation, expr_to_typespec(ans))
else
exla_op.(tensor, [n], expr_to_typespec(ans))
exla_op.(tensor, [fft_n], expr_to_typespec(ans))
end
end

Expand Down Expand Up @@ -1655,8 +1723,10 @@ defmodule EXLA.Defn do
Value.slice(tensor, starts, limit_indices, strides, typespec)

m < n ->
zero_value = if Nx.Type.complex?(output_type), do: Complex.new(0), else: 0

zero =
Value.constant(state.builder, [Complex.new(0)], Typespec.tensor(output_type, {}))
Value.constant(state.builder, [zero_value], Typespec.tensor(output_type, {}))

padding_config =
{0, 0, 0}
Expand Down
4 changes: 2 additions & 2 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ defmodule EXLA.MLIR.Value do
end

def fft(%Value{function: func} = value, fft_kind, fft_length, typespec)
when fft_kind in [:fft, :ifft]
when fft_kind in [:fft, :ifft, :rfft, :irfft]
when is_list(fft_length) or is_integer(fft_length) do
result_types = typespecs_to_mlir_types([typespec])

Expand Down Expand Up @@ -1070,7 +1070,7 @@ defmodule EXLA.MLIR.Value do
defp attr_transpose(value) when value in [:adjoint, :transpose, :no_transpose],
do: attr_enum("stablehlo", "transpose", value)

defp attr_fft_type(value) when value in [:fft, :ifft],
defp attr_fft_type(value) when value in [:fft, :ifft, :rfft, :irfft],
do: attr_enum("stablehlo", "fft_type", value)

defp attr_enum(dialect, enum_name, value) do
Expand Down
Loading
Loading