@@ -767,6 +767,63 @@ defmodule EXLA.Defn do
767767 { fft2 ( & Value . fft ( & 1 , :ifft , & 2 , & 3 ) , [ tensor , opts ] , expr , state ) , cache }
768768 end
769769
770+ defp cached_recur_operator (
771+ :block ,
772+ % T { data: % Expr { args: [ % Nx.Block.RFFT { } = rfft_struct , [ tensor ] , expr , _callback ] } } ,
773+ state ,
774+ cache
775+ ) do
776+ { tensor , cache } = recur_operator ( tensor , state , cache ) |> unwrap_single_tensor! ( )
777+
778+ opts = [ length: rfft_struct . length , axis: rfft_struct . axis ]
779+
780+ opts =
781+ if eps = rfft_struct . eps do
782+ Keyword . put ( opts , :eps , eps )
783+ else
784+ opts
785+ end
786+
787+ # expr.type is complex; input tensor is real
788+ input_type = Nx.Type . to_real ( expr . type )
789+
790+ { fft ( & Value . fft ( & 1 , :rfft , & 2 , & 3 ) , input_type , expr . type , [ tensor , opts ] , expr , state ) ,
791+ cache }
792+ end
793+
794+ defp cached_recur_operator (
795+ :block ,
796+ % T { data: % Expr { args: [ % Nx.Block.IRFFT { } = irfft_struct , [ tensor ] , expr , _callback ] } } ,
797+ state ,
798+ cache
799+ ) do
800+ { tensor , cache } = recur_operator ( tensor , state , cache ) |> unwrap_single_tensor! ( )
801+
802+ opts = [ length: irfft_struct . length , axis: irfft_struct . axis ]
803+
804+ opts =
805+ if eps = irfft_struct . eps do
806+ Keyword . put ( opts , :eps , eps )
807+ else
808+ opts
809+ end
810+
811+ # expr.type is real; input tensor is complex.
812+ # pad_n = div(n,2)+1 (the expected input size), while fft_n = n (the output length).
813+ n = irfft_struct . length
814+ input_type = Nx.Type . to_complex ( expr . type )
815+
816+ { fft (
817+ & Value . fft ( & 1 , :irfft , & 2 , & 3 ) ,
818+ input_type ,
819+ expr . type ,
820+ div ( n , 2 ) + 1 ,
821+ [ tensor , opts ] ,
822+ expr ,
823+ state
824+ ) , cache }
825+ end
826+
770827 defp cached_recur_operator ( :block , % T { data: % Expr { args: args } } , state , cache ) do
771828 [ struct , in_args , expr , _callback ] = args
772829 % module { } = struct
@@ -1233,10 +1290,10 @@ defmodule EXLA.Defn do
12331290 end
12341291
12351292 defp to_operator ( :fft , [ % Value { } | _ ] = args , out , state ) ,
1236- do: fft ( & Value . fft ( & 1 , :fft , & 2 , & 3 ) , args , out , state )
1293+ do: fft ( & Value . fft ( & 1 , :fft , & 2 , & 3 ) , out . type , out . type , args , out , state )
12371294
12381295 defp to_operator ( :ifft , [ % Value { } | _ ] = args , out , state ) ,
1239- do: fft ( & Value . fft ( & 1 , :ifft , & 2 , & 3 ) , args , out , state )
1296+ do: fft ( & Value . fft ( & 1 , :ifft , & 2 , & 3 ) , out . type , out . type , args , out , state )
12401297
12411298 defp to_operator ( :is_nan , [ % Value { } = arg ] , out , _state ) ,
12421299 do: Value . is_nan ( arg , expr_to_typespec ( out ) )
@@ -1561,16 +1618,16 @@ defmodule EXLA.Defn do
15611618 EXLA.Lib . argsort ( state . builder , tensor , dimension , stable , comp , ans . type )
15621619 end
15631620
1564- defp fft ( exla_op , [ % Value { } = tensor , opts ] , % { type: type } = ans , state ) do
1565- n = opts [ :length ]
1621+ defp fft ( exla_op , input_type , output_type , pad_n \\ nil , [ % Value { } = tensor , opts ] , ans , state ) do
1622+ fft_n = opts [ :length ]
1623+ pad_n = pad_n || fft_n
15661624 axis = opts [ :axis ]
1567- output_type = Nx.Type . to_complex ( type )
1568- tensor = to_type ( tensor , output_type )
1625+ tensor = to_type ( tensor , input_type )
15691626
15701627 shape = op_shape ( tensor )
15711628 m = elem ( shape , axis )
15721629
1573- tensor = fft_pad_or_slice ( tensor , m , n , axis , shape , output_type , state )
1630+ tensor = fft_pad_or_slice ( tensor , m , pad_n , axis , shape , input_type , state )
15741631
15751632 last_axis = tuple_size ( shape ) - 1
15761633
@@ -1582,15 +1639,26 @@ defmodule EXLA.Defn do
15821639 ax -> ax
15831640 end )
15841641
1585- { transposed_shape , _ } = Nx.Shape . transpose ( ans . shape , permutation , ans . names )
1586- transposed_typespec = Typespec . tensor ( ans . type , transposed_shape )
1642+ padded_shape = op_shape ( tensor )
1643+
1644+ { transposed_input_shape , _ } =
1645+ Nx.Shape . transpose (
1646+ padded_shape ,
1647+ permutation ,
1648+ List . duplicate ( nil , tuple_size ( padded_shape ) )
1649+ )
1650+
1651+ transposed_input_typespec = Typespec . tensor ( input_type , transposed_input_shape )
1652+
1653+ { transposed_output_shape , _ } = Nx.Shape . transpose ( ans . shape , permutation , ans . names )
1654+ transposed_output_typespec = Typespec . tensor ( output_type , transposed_output_shape )
15871655
15881656 tensor
1589- |> Value . transpose ( permutation , transposed_typespec )
1590- |> exla_op . ( [ n ] , transposed_typespec )
1657+ |> Value . transpose ( permutation , transposed_input_typespec )
1658+ |> exla_op . ( [ fft_n ] , transposed_output_typespec )
15911659 |> Value . transpose ( permutation , expr_to_typespec ( ans ) )
15921660 else
1593- exla_op . ( tensor , [ n ] , expr_to_typespec ( ans ) )
1661+ exla_op . ( tensor , [ fft_n ] , expr_to_typespec ( ans ) )
15941662 end
15951663 end
15961664
@@ -1655,8 +1723,10 @@ defmodule EXLA.Defn do
16551723 Value . slice ( tensor , starts , limit_indices , strides , typespec )
16561724
16571725 m < n ->
1726+ zero_value = if Nx.Type . complex? ( output_type ) , do: Complex . new ( 0 ) , else: 0
1727+
16581728 zero =
1659- Value . constant ( state . builder , [ Complex . new ( 0 ) ] , Typespec . tensor ( output_type , { } ) )
1729+ Value . constant ( state . builder , [ zero_value ] , Typespec . tensor ( output_type , { } ) )
16601730
16611731 padding_config =
16621732 { 0 , 0 , 0 }
0 commit comments