@@ -601,11 +601,12 @@ defmodule EXLA.Defn do
601601 end
602602
603603 defp cached_recur_operator (
604- :optional ,
604+ :block ,
605605 % T {
606606 data: % Expr {
607607 args: [
608- % { data: % { op: :qr , args: [ tensor , _opts ] } } ,
608+ % Nx.Block.QR { } ,
609+ [ tensor ] ,
609610 { % { type: { type_kind , _ } } = q_expr , r_expr } ,
610611 _callback
611612 ]
@@ -631,11 +632,12 @@ defmodule EXLA.Defn do
631632 end
632633
633634 defp cached_recur_operator (
634- :optional ,
635+ :block ,
635636 % T {
636637 data: % Expr {
637638 args: [
638- % { data: % { op: :eigh , args: [ tensor , _opts ] } } ,
639+ % Nx.Block.Eigh { } ,
640+ [ tensor ] ,
639641 { % { type: { evec_type_kind , _ } } = eigenvals_expr ,
640642 % { type: { eval_type_kind , _ } } = eigenvecs_expr } ,
641643 _callback
@@ -672,16 +674,15 @@ defmodule EXLA.Defn do
672674 end
673675
674676 defp cached_recur_operator (
675- :optional ,
677+ :block ,
676678 % T {
677679 data: % Expr {
678- args: [ % { data: % { op: :take , args: [ tensor , indices , opts ] } } , expr , _callback ]
680+ args: [ % Nx.Block.Take { axis: axis } , [ tensor , indices ] , expr , _callback ]
679681 }
680682 } ,
681683 state ,
682684 cache
683685 ) do
684- axis = opts [ :axis ]
685686 { tensor , cache } = recur_operator ( tensor , state , cache ) |> unwrap_single_tensor! ( )
686687 { indices , cache } = recur_operator ( indices , state , cache ) |> unwrap_single_tensor! ( )
687688
@@ -714,56 +715,80 @@ defmodule EXLA.Defn do
714715 end
715716
716717 defp cached_recur_operator (
717- :optional ,
718- % T { data: % Expr { args: [ % { data: % { op: :top_k , args: [ tensor , opts ] } } , expr , _callback ] } } ,
718+ :block ,
719+ % T { data: % Expr { args: [ % Nx.Block.TopK { k: k } , [ tensor ] , expr , _callback ] } } ,
719720 state ,
720721 cache
721722 ) do
722723 { tensor , cache } = recur_operator ( tensor , state , cache ) |> unwrap_single_tensor! ( )
723724 { values , idx } = expr
724725 typespecs = [ expr_to_typespec ( values ) , expr_to_typespec ( idx ) ]
725- results = Value . top_k ( tensor , opts [ :k ] , typespecs )
726+ results = Value . top_k ( tensor , k , typespecs )
726727 { results , cache }
727728 end
728729
729730 defp cached_recur_operator (
730- :optional ,
731- % T { data: % Expr { args: [ % { data: % { op: :fft2 , args: [ tensor , opts ] } } , expr , _callback ] } } ,
731+ :block ,
732+ % T { data: % Expr { args: [ % Nx.Block.FFT2 { } = fft2_struct , [ tensor ] , expr , _callback ] } } ,
732733 state ,
733734 cache
734735 ) do
735736 { tensor , cache } = recur_operator ( tensor , state , cache ) |> unwrap_single_tensor! ( )
736737
738+ opts = [ lengths: fft2_struct . lengths , axes: fft2_struct . axes ]
739+
740+ opts =
741+ if eps = fft2_struct . eps do
742+ Keyword . put ( opts , :eps , eps )
743+ else
744+ opts
745+ end
746+
737747 { fft2 ( & Value . fft ( & 1 , :fft , & 2 , & 3 ) , [ tensor , opts ] , expr , state ) , cache }
738748 end
739749
740750 defp cached_recur_operator (
741- :optional ,
742- % T { data: % Expr { args: [ % { data: % { op: :ifft2 , args: [ tensor , opts ] } } , expr , _callback ] } } ,
751+ :block ,
752+ % T { data: % Expr { args: [ % Nx.Block.IFFT2 { } = ifft2_struct , [ tensor ] , expr , _callback ] } } ,
743753 state ,
744754 cache
745755 ) do
746756 { tensor , cache } = recur_operator ( tensor , state , cache ) |> unwrap_single_tensor! ( )
747757
758+ opts = [ lengths: ifft2_struct . lengths , axes: ifft2_struct . axes ]
759+
760+ opts =
761+ if eps = ifft2_struct . eps do
762+ Keyword . put ( opts , :eps , eps )
763+ else
764+ opts
765+ end
766+
748767 { fft2 ( & Value . fft ( & 1 , :ifft , & 2 , & 3 ) , [ tensor , opts ] , expr , state ) , cache }
749768 end
750769
751- defp cached_recur_operator ( :optional , % T { data: % Expr { args: args } } , state , cache ) do
752- [ call , expr , _callback ] = args
753- % { data: % { args: in_args , op: op } } = call
754-
755- { args , opts } = Enum . split_while ( in_args , & ( not is_list ( & 1 ) ) )
770+ defp cached_recur_operator ( :block , % T { data: % Expr { args: args } } , state , cache ) do
771+ [ struct , in_args , expr , _callback ] = args
772+ % module { } = struct
756773
757- { call_args , cache } = Enum . map_reduce ( args , cache , & recur_operator ( & 1 , state , & 2 ) )
758- key = computation_key ( op , call_args ++ opts )
774+ { call_args , cache } = Enum . map_reduce ( in_args , cache , & recur_operator ( & 1 , state , & 2 ) )
775+ key = computation_key ( module , [ struct | call_args ] )
759776
760777 { call_body , cache } =
761778 case cache do
762779 % { ^ key => computation } ->
763780 { computation , cache }
764781
765782 % { } ->
766- { computation , cache } = optional_computation ( "optional" , call_args , expr , state , cache )
783+ { computation , cache } =
784+ block_computation (
785+ block_subfunction_description ( struct ) ,
786+ call_args ,
787+ expr ,
788+ state ,
789+ cache
790+ )
791+
767792 { computation , Map . put ( cache , key , computation ) }
768793 end
769794
@@ -1818,8 +1843,14 @@ defmodule EXLA.Defn do
18181843 { region , merge_outfeed ( cache , comp_cache ) }
18191844 end
18201845
1821- defp optional_computation ( name , args , expr , % { builder: % Function { } } = state , cache ) do
1822- % Function { module: module , name: name } = subbuilder ( state . builder , name )
1846+ defp block_subfunction_description ( % module { } = _ ) do
1847+ module
1848+ |> Atom . to_string ( )
1849+ |> String . replace ( "." , "_" )
1850+ end
1851+
1852+ defp block_computation ( description , args , expr , % { builder: % Function { } } = state , cache ) do
1853+ % Function { module: module , name: name } = subbuilder ( state . builder , description )
18231854
18241855 arg_typespecs = Enum . map ( args , & Value . get_typespec / 1 )
18251856 out_typespecs = container_to_typespecs ( expr )
0 commit comments