Skip to content

Commit cccc4f7

Browse files
authored
feat: add rfft/irfft (#1736)
1 parent 3e6e723 commit cccc4f7

10 files changed

Lines changed: 505 additions & 21 deletions

File tree

exla/lib/exla/defn.ex

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,63 @@ defmodule EXLA.Defn do
767767
{fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr, state), cache}
768768
end
769769

770+
defp cached_recur_operator(
771+
:block,
772+
%T{data: %Expr{args: [%Nx.Block.RFFT{} = rfft_struct, [tensor], expr, _callback]}},
773+
state,
774+
cache
775+
) do
776+
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
777+
778+
opts = [length: rfft_struct.length, axis: rfft_struct.axis]
779+
780+
opts =
781+
if eps = rfft_struct.eps do
782+
Keyword.put(opts, :eps, eps)
783+
else
784+
opts
785+
end
786+
787+
# expr.type is complex; input tensor is real
788+
input_type = Nx.Type.to_real(expr.type)
789+
790+
{fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr, state),
791+
cache}
792+
end
793+
794+
defp cached_recur_operator(
795+
:block,
796+
%T{data: %Expr{args: [%Nx.Block.IRFFT{} = irfft_struct, [tensor], expr, _callback]}},
797+
state,
798+
cache
799+
) do
800+
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
801+
802+
opts = [length: irfft_struct.length, axis: irfft_struct.axis]
803+
804+
opts =
805+
if eps = irfft_struct.eps do
806+
Keyword.put(opts, :eps, eps)
807+
else
808+
opts
809+
end
810+
811+
# expr.type is real; input tensor is complex.
812+
# pad_n = div(n,2)+1 (the expected input size), while fft_n = n (the output length).
813+
n = irfft_struct.length
814+
input_type = Nx.Type.to_complex(expr.type)
815+
816+
{fft(
817+
&Value.fft(&1, :irfft, &2, &3),
818+
input_type,
819+
expr.type,
820+
div(n, 2) + 1,
821+
[tensor, opts],
822+
expr,
823+
state
824+
), cache}
825+
end
826+
770827
defp cached_recur_operator(:block, %T{data: %Expr{args: args}}, state, cache) do
771828
[struct, in_args, expr, _callback] = args
772829
%module{} = struct
@@ -1233,10 +1290,10 @@ defmodule EXLA.Defn do
12331290
end
12341291

12351292
defp to_operator(:fft, [%Value{} | _] = args, out, state),
1236-
do: fft(&Value.fft(&1, :fft, &2, &3), args, out, state)
1293+
do: fft(&Value.fft(&1, :fft, &2, &3), out.type, out.type, args, out, state)
12371294

12381295
defp to_operator(:ifft, [%Value{} | _] = args, out, state),
1239-
do: fft(&Value.fft(&1, :ifft, &2, &3), args, out, state)
1296+
do: fft(&Value.fft(&1, :ifft, &2, &3), out.type, out.type, args, out, state)
12401297

12411298
defp to_operator(:is_nan, [%Value{} = arg], out, _state),
12421299
do: Value.is_nan(arg, expr_to_typespec(out))
@@ -1561,16 +1618,16 @@ defmodule EXLA.Defn do
15611618
EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type)
15621619
end
15631620

1564-
defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do
1565-
n = opts[:length]
1621+
defp fft(exla_op, input_type, output_type, pad_n \\ nil, [%Value{} = tensor, opts], ans, state) do
1622+
fft_n = opts[:length]
1623+
pad_n = pad_n || fft_n
15661624
axis = opts[:axis]
1567-
output_type = Nx.Type.to_complex(type)
1568-
tensor = to_type(tensor, output_type)
1625+
tensor = to_type(tensor, input_type)
15691626

15701627
shape = op_shape(tensor)
15711628
m = elem(shape, axis)
15721629

1573-
tensor = fft_pad_or_slice(tensor, m, n, axis, shape, output_type, state)
1630+
tensor = fft_pad_or_slice(tensor, m, pad_n, axis, shape, input_type, state)
15741631

15751632
last_axis = tuple_size(shape) - 1
15761633

@@ -1582,15 +1639,26 @@ defmodule EXLA.Defn do
15821639
ax -> ax
15831640
end)
15841641

1585-
{transposed_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names)
1586-
transposed_typespec = Typespec.tensor(ans.type, transposed_shape)
1642+
padded_shape = op_shape(tensor)
1643+
1644+
{transposed_input_shape, _} =
1645+
Nx.Shape.transpose(
1646+
padded_shape,
1647+
permutation,
1648+
List.duplicate(nil, tuple_size(padded_shape))
1649+
)
1650+
1651+
transposed_input_typespec = Typespec.tensor(input_type, transposed_input_shape)
1652+
1653+
{transposed_output_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names)
1654+
transposed_output_typespec = Typespec.tensor(output_type, transposed_output_shape)
15871655

15881656
tensor
1589-
|> Value.transpose(permutation, transposed_typespec)
1590-
|> exla_op.([n], transposed_typespec)
1657+
|> Value.transpose(permutation, transposed_input_typespec)
1658+
|> exla_op.([fft_n], transposed_output_typespec)
15911659
|> Value.transpose(permutation, expr_to_typespec(ans))
15921660
else
1593-
exla_op.(tensor, [n], expr_to_typespec(ans))
1661+
exla_op.(tensor, [fft_n], expr_to_typespec(ans))
15941662
end
15951663
end
15961664

@@ -1655,8 +1723,10 @@ defmodule EXLA.Defn do
16551723
Value.slice(tensor, starts, limit_indices, strides, typespec)
16561724

16571725
m < n ->
1726+
zero_value = if Nx.Type.complex?(output_type), do: Complex.new(0), else: 0
1727+
16581728
zero =
1659-
Value.constant(state.builder, [Complex.new(0)], Typespec.tensor(output_type, {}))
1729+
Value.constant(state.builder, [zero_value], Typespec.tensor(output_type, {}))
16601730

16611731
padding_config =
16621732
{0, 0, 0}

exla/lib/exla/mlir/value.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ defmodule EXLA.MLIR.Value do
372372
end
373373

374374
def fft(%Value{function: func} = value, fft_kind, fft_length, typespec)
375-
when fft_kind in [:fft, :ifft]
375+
when fft_kind in [:fft, :ifft, :rfft, :irfft]
376376
when is_list(fft_length) or is_integer(fft_length) do
377377
result_types = typespecs_to_mlir_types([typespec])
378378

@@ -1070,7 +1070,7 @@ defmodule EXLA.MLIR.Value do
10701070
defp attr_transpose(value) when value in [:adjoint, :transpose, :no_transpose],
10711071
do: attr_enum("stablehlo", "transpose", value)
10721072

1073-
defp attr_fft_type(value) when value in [:fft, :ifft],
1073+
defp attr_fft_type(value) when value in [:fft, :ifft, :rfft, :irfft],
10741074
do: attr_enum("stablehlo", "fft_type", value)
10751075

10761076
defp attr_enum(dialect, enum_name, value) do

0 commit comments

Comments
 (0)