Skip to content

Commit 3e6e723

Browse files
committed
Tiny fixes on block
1 parent ca587aa commit 3e6e723

5 files changed

Lines changed: 19 additions & 31 deletions

File tree

exla/lib/exla/backend.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ defmodule EXLA.Backend do
330330
{tensors, rest} = Enum.split_while(args, &is_struct(&1, Nx.Tensor))
331331

332332
wrapper_fun = fn tensors ->
333-
Nx.Defn.Expr.block(struct, nil, Tuple.to_list(tensors) ++ rest, fun)
333+
Nx.Defn.Expr.block(struct, Tuple.to_list(tensors) ++ rest, fun)
334334
end
335335

336336
jit([], wrapper_fun, tensors, [List.to_tuple(tensors)])

nx/lib/nx.ex

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8935,7 +8935,7 @@ defmodule Nx do
89358935
else
89368936
out = %{a | names: [], shape: {}, type: {:u, 8}}
89378937

8938-
block(struct(Nx.Block.AllClose, opts), [a, b], out, fn %Nx.Block.AllClose{} = o, a, b ->
8938+
block(struct!(Nx.Block.AllClose, opts), [a, b], out, fn %Nx.Block.AllClose{} = o, a, b ->
89398939
vectorized_all_close(a, b,
89408940
equal_nan: o.equal_nan,
89418941
rtol: o.rtol,
@@ -14327,19 +14327,18 @@ defmodule Nx do
1432714327
indices = devectorize(indices, keep_names: false)
1432814328
out = %{tensor | shape: inner_shape, names: inner_names}
1432914329

14330-
block(struct(Nx.Block.Take, axis: axis), [tensor, indices], out, fn %Nx.Block.Take{},
14331-
tensor,
14332-
indices ->
14333-
gather_indices = new_axis(indices, rank(indices))
14334-
{indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices))
14335-
{leading, trailing} = Enum.split(tensor_axes, axis)
14330+
block(struct!(Nx.Block.Take, axis: axis), [tensor, indices], out, fn
14331+
%Nx.Block.Take{}, tensor, indices ->
14332+
gather_indices = new_axis(indices, rank(indices))
14333+
{indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices))
14334+
{leading, trailing} = Enum.split(tensor_axes, axis)
1433614335

14337-
transpose_axes = leading ++ indices_axes ++ trailing
14336+
transpose_axes = leading ++ indices_axes ++ trailing
1433814337

14339-
tensor
14340-
|> gather(gather_indices, axes: [axis])
14341-
|> transpose(axes: transpose_axes)
14342-
|> rename(inner_names)
14338+
tensor
14339+
|> gather(gather_indices, axes: [axis])
14340+
|> transpose(axes: transpose_axes)
14341+
|> rename(inner_names)
1434314342
end)
1434414343
end
1434514344
end
@@ -14509,7 +14508,7 @@ defmodule Nx do
1450914508

1451014509
result =
1451114510
block(
14512-
struct(Nx.Block.TakeAlongAxis, axis: axis),
14511+
struct!(Nx.Block.TakeAlongAxis, axis: axis),
1451314512
[tensor, indices],
1451414513
out,
1451514514
fn %Nx.Block.TakeAlongAxis{}, tensor, indices ->
@@ -15327,7 +15326,7 @@ defmodule Nx do
1532715326
out_indices = %{tensor | shape: output_shape, names: output_names, type: {:s, 32}}
1532815327

1532915328
block(
15330-
struct(Nx.Block.TopK, k: opts[:k]),
15329+
struct!(Nx.Block.TopK, k: opts[:k]),
1533115330
[tensor],
1533215331
{out_values, out_indices},
1533315332
fn %Nx.Block.TopK{} = top_k, tensor ->
@@ -16774,9 +16773,9 @@ defmodule Nx do
1677416773

1677516774
block_struct =
1677616775
if kind == :fft2 do
16777-
struct(Nx.Block.FFT2, eps: opts[:eps], lengths: [l1, l2], axes: [ax1, ax2])
16776+
struct!(Nx.Block.FFT2, eps: opts[:eps], lengths: [l1, l2], axes: [ax1, ax2])
1677816777
else
16779-
struct(Nx.Block.IFFT2, eps: opts[:eps], lengths: [l1, l2], axes: [ax1, ax2])
16778+
struct!(Nx.Block.IFFT2, eps: opts[:eps], lengths: [l1, l2], axes: [ax1, ax2])
1678016779
end
1678116780

1678216781
block(block_struct, [tensor], out, fn s, tensor ->

nx/lib/nx/defn/evaluator.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ defmodule Nx.Defn.Evaluator do
366366
end
367367

368368
defp eval_apply(:block, [struct, in_args, expr, expr_cache], ans, state, caches) do
369-
{in_args, caches} = Tree.map_block_args(in_args, caches, &eval(&1, state, &2))
369+
{in_args, caches} = Enum.map_reduce(in_args, caches, &eval(&1, state, &2))
370370
{param_prefix, _} = Enum.split_while(in_args, &(not is_list(&1)))
371371
backend = Nx.Shared.list_impl!(param_prefix)
372372

nx/lib/nx/defn/expr.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ defmodule Nx.Defn.Expr do
808808
@behaviour Nx.Backend
809809

810810
@impl true
811-
def block(struct, _output, in_args, fun) do
811+
def block(struct, _output \\ nil, in_args, fun) do
812812
expr_block(struct, in_args, fun)
813813
end
814814

nx/lib/nx/defn/tree.ex

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ defmodule Nx.Defn.Tree do
181181

182182
def apply_args(%T{data: %Expr{op: :block, args: args}}, type, acc, fun) do
183183
[struct, in_args, expr, callback] = args
184-
{in_args, acc} = map_block_args(in_args, acc, fun)
184+
{in_args, acc} = Enum.map_reduce(in_args, acc, fun)
185185

186186
{expr, acc} =
187187
case type do
@@ -254,15 +254,4 @@ defmodule Nx.Defn.Tree do
254254
arg, acc -> {arg, acc}
255255
end)
256256
end
257-
258-
@doc false
259-
def map_block_args(list, acc, fun) when is_list(list) do
260-
Enum.map_reduce(list, acc, fn
261-
%T{} = arg, acc ->
262-
fun.(arg, acc)
263-
264-
arg, acc ->
265-
{arg, acc}
266-
end)
267-
end
268257
end

0 commit comments

Comments
 (0)