@@ -8935,7 +8935,7 @@ defmodule Nx do
89358935 else
89368936 out = % { a | names: [ ] , shape: { } , type: { :u , 8 } }
89378937
8938- block ( struct ( Nx.Block.AllClose , opts ) , [ a , b ] , out , fn % Nx.Block.AllClose { } = o , a , b ->
8938+ block ( struct! ( Nx.Block.AllClose , opts ) , [ a , b ] , out , fn % Nx.Block.AllClose { } = o , a , b ->
89398939 vectorized_all_close ( a , b ,
89408940 equal_nan: o . equal_nan ,
89418941 rtol: o . rtol ,
@@ -14327,19 +14327,18 @@ defmodule Nx do
1432714327 indices = devectorize ( indices , keep_names: false )
1432814328 out = % { tensor | shape: inner_shape , names: inner_names }
1432914329
14330- block ( struct ( Nx.Block.Take , axis: axis ) , [ tensor , indices ] , out , fn % Nx.Block.Take { } ,
14331- tensor ,
14332- indices ->
14333- gather_indices = new_axis ( indices , rank ( indices ) )
14334- { indices_axes , tensor_axes } = Enum . split ( axes ( inner_shape ) , rank ( indices ) )
14335- { leading , trailing } = Enum . split ( tensor_axes , axis )
14330+ block ( struct! ( Nx.Block.Take , axis: axis ) , [ tensor , indices ] , out , fn
14331+ % Nx.Block.Take { } , tensor , indices ->
14332+ gather_indices = new_axis ( indices , rank ( indices ) )
14333+ { indices_axes , tensor_axes } = Enum . split ( axes ( inner_shape ) , rank ( indices ) )
14334+ { leading , trailing } = Enum . split ( tensor_axes , axis )
1433614335
14337- transpose_axes = leading ++ indices_axes ++ trailing
14336+ transpose_axes = leading ++ indices_axes ++ trailing
1433814337
14339- tensor
14340- |> gather ( gather_indices , axes: [ axis ] )
14341- |> transpose ( axes: transpose_axes )
14342- |> rename ( inner_names )
14338+ tensor
14339+ |> gather ( gather_indices , axes: [ axis ] )
14340+ |> transpose ( axes: transpose_axes )
14341+ |> rename ( inner_names )
1434314342 end )
1434414343 end
1434514344 end
@@ -14509,7 +14508,7 @@ defmodule Nx do
1450914508
1451014509 result =
1451114510 block (
14512- struct ( Nx.Block.TakeAlongAxis , axis: axis ) ,
14511+ struct! ( Nx.Block.TakeAlongAxis , axis: axis ) ,
1451314512 [ tensor , indices ] ,
1451414513 out ,
1451514514 fn % Nx.Block.TakeAlongAxis { } , tensor , indices ->
@@ -15327,7 +15326,7 @@ defmodule Nx do
1532715326 out_indices = % { tensor | shape: output_shape , names: output_names , type: { :s , 32 } }
1532815327
1532915328 block (
15330- struct ( Nx.Block.TopK , k: opts [ :k ] ) ,
15329+ struct! ( Nx.Block.TopK , k: opts [ :k ] ) ,
1533115330 [ tensor ] ,
1533215331 { out_values , out_indices } ,
1533315332 fn % Nx.Block.TopK { } = top_k , tensor ->
@@ -16774,9 +16773,9 @@ defmodule Nx do
1677416773
1677516774 block_struct =
1677616775 if kind == :fft2 do
16777- struct ( Nx.Block.FFT2 , eps: opts [ :eps ] , lengths: [ l1 , l2 ] , axes: [ ax1 , ax2 ] )
16776+ struct! ( Nx.Block.FFT2 , eps: opts [ :eps ] , lengths: [ l1 , l2 ] , axes: [ ax1 , ax2 ] )
1677816777 else
16779- struct ( Nx.Block.IFFT2 , eps: opts [ :eps ] , lengths: [ l1 , l2 ] , axes: [ ax1 , ax2 ] )
16778+ struct! ( Nx.Block.IFFT2 , eps: opts [ :eps ] , lengths: [ l1 , l2 ] , axes: [ ax1 , ax2 ] )
1678016779 end
1678116780
1678216781 block ( block_struct , [ tensor ] , out , fn s , tensor ->
0 commit comments