diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index a681504b9f..73e29fa3e8 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -5,6 +5,8 @@ defmodule Nx.Defn.Grad do alias Nx.Tensor, as: T def transform(to_grad, fun, transform) do + to_grad = apply_boundary_broadcast(to_grad) + {to_grad, ids} = Composite.traverse(to_grad, %{}, fn to_grad, ids -> to_grad = @@ -26,15 +28,20 @@ defmodule Nx.Defn.Grad do {parents, nodes} = parents_tree(transformed_expr, ids) - to_grad_ids = {to_grad, ids} - grads = %{transformed_expr.data.id => [constant(1.0, transformed_expr)]} + output_vectorized_axes = transformed_expr.vectorized_axes + batch_count = length(output_vectorized_axes) + to_grad_ids = {to_grad, ids, batch_count} + + # Seed the backward pass in devectorized space. + devec_expr = Nx.devectorize(transformed_expr, keep_names: false) + grads = %{transformed_expr.data.id => [constant(1.0, devec_expr)]} {graded, _} = Composite.traverse( to_grad, {nodes, grads}, fn node, acc -> - to_grad(node, to_grad_ids, parents, acc) + to_grad(node, to_grad_ids, parents, acc, output_vectorized_axes) end ) @@ -46,6 +53,21 @@ defmodule Nx.Defn.Grad do Expr.constant(%{t | names: names, type: {:f, 32}}, float, []) end + # Align heterogenous vectorized inputs to the union of vec axes — matching + # the implicit alignment the forward pass already performs, so the grad + # recursion sees a homogeneous case. + defp apply_boundary_broadcast(to_grad) do + case Composite.flatten_list([to_grad]) do + [_ | _] = flat -> + broadcast = Nx.broadcast_vectors(flat) + {result, []} = Composite.traverse(to_grad, broadcast, fn _, [h | t] -> {h, t} end) + result + + _ -> + to_grad + end + end + defp validate_expr!(%T{data: %Expr{}} = expr) do expr end @@ -91,89 +113,57 @@ defmodule Nx.Defn.Grad do Composite.reduce( expr, {%{}, nodes}, - &recur_parents_tree( - Nx.devectorize(&1, keep_names: true), - &2, - Keyword.keys(&1.vectorized_axes) - ) + &recur_parents_tree(Nx.devectorize(&1, keep_names: true), &2) ) end - defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}, vectorized_names) do + defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}) do case nodes do %{^id => _} -> {parents, nodes} %{} -> - # We use this to compute the proper axis sizes for the tensor - nodes = Map.put(nodes, id, {t, vectorized_names}) - - parents_args(op, t, id, {parents, nodes}, vectorized_names) + nodes = Map.put(nodes, id, t) + parents_args(op, t, id, {parents, nodes}) end end - defp parents_args( - :metadata, - %{data: %{args: [_, %{stop_grad: true}]}}, - _id, - acc, - _parent_vectorized_names - ) do + defp parents_args(:metadata, %{data: %{args: [_, %{stop_grad: true}]}}, _id, acc) do acc end - defp parents_args(:runtime_call, _expr, _id, acc, _parent_vectorized_names) do + defp parents_args(:runtime_call, _expr, _id, acc) do acc end - defp parents_args( - :block, - %{data: %{args: [struct, in_args, _expr, callback]}} = t, - id, - acc, - parent_vectorized_names - ) do + defp parents_args(:block, %{data: %{args: [struct, in_args, _expr, callback]}} = t, id, acc) do expr = apply(callback, [struct | in_args]) - # Now traverse over the optional expression where args are the new parameters. + # Now traverse over the block expression where args are the new parameters. # Once we access the parameter itself, we point the parameter to the arg. 
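+    # Note: per-node vectorized-name bookkeeping is no longer threaded through
+    # here — the backward pass is seeded in devectorized space and args are
+    # devectorized when their grads are computed, so the block parameters can
+    # be traversed as-is.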
- {{parents, nodes}, _} = - Composite.reduce(expr, {acc, parent_vectorized_names}, fn - expr, {{parents, nodes}, expr_vectorized_names} -> - arg_vectorized_names = compute_arg_vectorized_names(expr, expr_vectorized_names) - parents = Map.update(parents, expr.data.id, [id], &[id | &1]) - - acc = - recur_parents_tree( - expr, - {parents, nodes}, - arg_vectorized_names - ) - - {acc, expr_vectorized_names} + {parents, nodes} = + Composite.reduce(expr, acc, fn expr, {parents, nodes} -> + parents = Map.update(parents, expr.data.id, [id], &[id | &1]) + recur_parents_tree(expr, {parents, nodes}) end) - updated_node = - {put_in(t.data.args, [struct, in_args, expr, callback]), parent_vectorized_names} - + updated_node = put_in(t.data.args, [struct, in_args, expr, callback]) {parents, Map.put(nodes, id, updated_node)} end # We register cond as a special node to avoid pretraversing it. # Instead we traverse it early on on the grad computation. - defp parents_args(:cond, _, id, {parents, nodes}, _parent_vectorized_names) do + defp parents_args(:cond, _, id, {parents, nodes}) do {Map.update(parents, __MODULE__, [id], &[id | &1]), nodes} end - defp parents_args(op, t, parent_id, acc, parent_vectorized_names) do + defp parents_args(op, t, parent_id, acc) do reduce_args(op, t, acc, fn arg, {parents, nodes} -> if arg.data.op in @constants do {parents, nodes} else - arg_vectorized_names = compute_arg_vectorized_names(t, parent_vectorized_names) parents = Map.update(parents, arg.data.id, [parent_id], &[parent_id | &1]) - - recur_parents_tree(arg, {parents, nodes}, arg_vectorized_names) + recur_parents_tree(arg, {parents, nodes}) end end) end @@ -207,14 +197,28 @@ defmodule Nx.Defn.Grad do ## Recursion - defp to_grad(arg, to_grad_ids, parents, acc) do + defp to_grad(arg, to_grad_ids, parents, acc, output_vectorized_axes) do id = arg.data.id acc = traverse_parents(__MODULE__, to_grad_ids, parents, acc) acc = traverse_parents(id, to_grad_ids, parents, acc) {nodes, grads} = acc res = sum_grad(Map.get(grads, id, [])) - {Nx.broadcast(res, arg), {nodes, grads}} + + res = + cond do + arg.vectorized_axes != [] and res.vectorized_axes == [] -> + Nx.vectorize(res, arg.vectorized_axes) + + arg.vectorized_axes == [] and output_vectorized_axes != [] and + tuple_size(res.shape) > tuple_size(arg.shape) -> + Nx.vectorize(res, output_vectorized_axes) + + true -> + Nx.broadcast(res, arg) + end + + {res, {nodes, grads}} end defp sum_grad([]), do: Expr.tensor(0.0) @@ -230,26 +234,19 @@ defmodule Nx.Defn.Grad do case nodes do %{^id => _} -> {nodes, grads} = traverse_parents(id, to_grad_ids, parents, {nodes, grads}) - {{ans, vectorized_names}, nodes} = Map.pop!(nodes, id) + {ans, nodes} = Map.pop!(nodes, id) %T{data: %Expr{op: op, args: args}} = ans {gs, grads} = Map.pop(grads, id) - {args, ans} = - if vectorized_names != [] do - args = - Enum.map(args, fn - %T{} = arg -> - revectorize_node(arg, vectorized_names) - - opt -> - opt - end) + # Devectorize args to match ans (already devec'd by parents_tree) + args = + Enum.map(args, fn + %T{vectorized_axes: va} = arg when va != [] -> + Nx.devectorize(arg, keep_names: false) - ans = Nx.vectorize(ans, vectorized_names) - {args, ans} - else - {args, ans} - end + other -> + other + end) case gs do nil -> @@ -269,22 +266,6 @@ defmodule Nx.Defn.Grad do end end - defp compute_arg_vectorized_names(%{vectorized_axes: vectorized_axes}, []), - do: Keyword.keys(vectorized_axes) - - defp compute_arg_vectorized_names( - %{vectorized_axes: vectorized_axes, names: names}, - parent_names - ) do - 
Keyword.keys(vectorized_axes) ++ Enum.filter(names, &(&1 in parent_names)) - end - - defp revectorize_node(node, vectorized_names) do - vectorized_names = compute_arg_vectorized_names(node, vectorized_names) - - Nx.vectorize(node, vectorized_names) - end - defp update_grads(:elem, [%{type: {:tuple, size}} = tuple, pos], _ans, g, _to_grad_ids, grads) do update_in(grads[tuple.data.id], fn tuple -> tuple = tuple || Tuple.duplicate([], size) @@ -329,7 +310,7 @@ defmodule Nx.Defn.Grad do {grad_body, _} = [arg] |> Composite.flatten_list() - |> Enum.map_reduce({nodes, while_grads}, &to_grad(&1, {arg, %{}}, parents, &2)) + |> Enum.map_reduce({nodes, while_grads}, &to_grad(&1, {arg, %{}, 0}, parents, &2, [])) # And finally build a new while. {_, while_gs} = @@ -350,7 +331,14 @@ defmodule Nx.Defn.Grad do grads end - defp update_grads(:cond, [clauses, last], _ans, gs, {to_grad, ids} = to_grad_ids, grads) do + defp update_grads( + :cond, + [clauses, last], + _ans, + gs, + {to_grad, ids, _batch_count} = to_grad_ids, + grads + ) do gs = List.wrap(gs) to_grad = Composite.flatten_list([to_grad]) @@ -364,7 +352,7 @@ defmodule Nx.Defn.Grad do end) {graded, _} = - Enum.map_reduce(to_grad, {nodes, grads}, &to_grad(&1, to_grad_ids, parents, &2)) + Enum.map_reduce(to_grad, {nodes, grads}, &to_grad(&1, to_grad_ids, parents, &2, [])) {head, graded} end) @@ -410,8 +398,8 @@ defmodule Nx.Defn.Grad do @reduced_grads [:add, :multiply, :pow] @verify_grad Application.compile_env(:nx, :verify_grad, false) - defp update_grads(op, args, ans, g, _to_grad_ids, grads) do - pairs = grad(op, args, ans, g) + defp update_grads(op, args, ans, g, {_to_grad, _ids, batch_count}, grads) do + pairs = grad(op, args, ans, g, batch_count) if @verify_grad do count = reduce_args(op, ans, 0, fn _arg, count -> count + 1 end) @@ -428,11 +416,11 @@ defmodule Nx.Defn.Grad do ## Gradients - defp grad(:parameter, [arg], _ans, g) do + defp grad(:parameter, [arg], _ans, g, _batch_count) do [{arg, g}] end - defp grad(:metadata, [_expr, %{custom_grad: {inputs, fun}}], _ans, g) do + defp grad(:metadata, [_expr, %{custom_grad: {inputs, fun}}], _ans, g, _batch_count) do # We don't expose the internal list representation to users g = if is_list(g), do: List.to_tuple(g), else: g args = fun.(g) @@ -444,21 +432,25 @@ defmodule Nx.Defn.Grad do Enum.zip(inputs, args) end - defp grad(:metadata, [expr, _], _ans, g) do + defp grad(:metadata, [expr, _], _ans, g, _batch_count) do [{expr, g}] end - defp grad(:select, [pred, on_true, on_false], ans, g) do + defp grad(:select, [pred, on_true, on_false], ans, g, batch_count) do d_on_true = Nx.select(pred, g, Expr.tensor(0.0)) d_on_false = Nx.select(pred, Expr.tensor(0.0), g) - [unbroadcast(on_true, d_on_true, ans), unbroadcast(on_false, d_on_false, ans)] + + [ + unbroadcast(on_true, d_on_true, ans, batch_count), + unbroadcast(on_false, d_on_false, ans, batch_count) + ] end - defp grad(:broadcast, [x, shape, axes], _ans, g) do + defp grad(:broadcast, [x, shape, axes], _ans, g, _batch_count) do [{x, grad_broadcast(x, shape, axes, g)}] end - defp grad(:clip, [operand, min, max], _ans, g) do + defp grad(:clip, [operand, min, max], _ans, g, _batch_count) do # w.r.t min w_min = Nx.select( @@ -485,19 +477,27 @@ defmodule Nx.Defn.Grad do ] end - defp grad(:squeeze, [x, axes], _ans, g) do + defp grad(:squeeze, [x, axes], _ans, g, _batch_count) do [{x, Nx.broadcast(g, x.shape, axes: Nx.axes(x.shape) -- axes)}] end - defp grad(:reshape, [x], _ans, g) do - [{x, Nx.reshape(g, x)}] + defp grad(:reshape, [x], _ans, g, batch_count) 
do + g = + if batch_count > 0 and Nx.size(g) > Nx.size(x) do + batch_dims = g.shape |> Tuple.to_list() |> Enum.take(batch_count) + Nx.reshape(g, List.to_tuple(batch_dims ++ Tuple.to_list(x.shape))) + else + Nx.reshape(g, x) + end + + [{x, g}] end - defp grad(:transpose, [x, axes], _ans, g) do + defp grad(:transpose, [x, axes], _ans, g, _batch_count) do [{x, Nx.transpose(g, axes: argsort(axes))}] end - defp grad(:pad, [x, value, padding_config], _ans, g) do + defp grad(:pad, [x, value, padding_config], _ans, g, _batch_count) do inverse_padding_config = Enum.map(padding_config, fn {lo, hi, _} -> {-lo, -hi, 0} end) unpadded = Nx.pad(g, 0.0, inverse_padding_config) @@ -511,7 +511,7 @@ defmodule Nx.Defn.Grad do [{x, g_operand}, {value, g_value}] end - defp grad(:slice, [x, start_indices, _lengths, strides], _ans, g) do + defp grad(:slice, [x, start_indices, _lengths, strides], _ans, g, _batch_count) do padding_config = Enum.map(strides, &{0, 0, &1 - 1}) pad_value = 0.0 g = Nx.pad(g, pad_value, padding_config) @@ -520,7 +520,7 @@ defmodule Nx.Defn.Grad do [{x, Nx.put_slice(zeros, start_indices, g)}] end - defp grad(:put_slice, [x, start_indices, update], _ans, g) do + defp grad(:put_slice, [x, start_indices, update], _ans, g, _batch_count) do zeros = Nx.broadcast(Expr.tensor(0.0), update) operand_t = Nx.put_slice(g, start_indices, zeros) @@ -529,7 +529,7 @@ defmodule Nx.Defn.Grad do [{x, operand_t}, {update, update_t}] end - defp grad(:indexed_put, [target, indices, updates, opts], _ans, g) do + defp grad(:indexed_put, [target, indices, updates, opts], _ans, g, _batch_count) do zeros = Nx.broadcast(Expr.tensor(0.0), updates) target_g = Nx.indexed_put(g, indices, zeros, opts) updates_g = g |> Nx.gather(indices, opts) |> Nx.reshape(updates.shape) @@ -538,7 +538,7 @@ defmodule Nx.Defn.Grad do [{target, target_g}, {indices, indices_g}, {updates, updates_g}] end - defp grad(:indexed_add, [target, indices, updates, opts], _ans, g) do + defp grad(:indexed_add, [target, indices, updates, opts], _ans, g, _batch_count) do target_g = g updates_g = g |> Nx.gather(indices, opts) |> Nx.reshape(updates.shape) indices_g = Nx.broadcast(Expr.tensor(0.0), indices) @@ -546,15 +546,15 @@ defmodule Nx.Defn.Grad do [{target, target_g}, {indices, indices_g}, {updates, updates_g}] end - defp grad(:reverse, [x, axes], _ans, g) do + defp grad(:reverse, [x, axes], _ans, g, _batch_count) do [{x, Nx.reverse(g, axes: axes)}] end - defp grad(:sum, [x, opts], _ans, g) do + defp grad(:sum, [x, opts], _ans, g, _batch_count) do [{x, reduce_g(x, opts, g)}] end - defp grad(:product, [x, opts], ans, g) do + defp grad(:product, [x, opts], ans, g, _batch_count) do axes = opts[:axes] || Nx.axes(x) unsqueezed_shape = Enum.reduce(axes, Nx.shape(x), &put_elem(&2, &1, 1)) g = Nx.reshape(g, unsqueezed_shape) @@ -589,7 +589,7 @@ defmodule Nx.Defn.Grad do @reduce_min_max_ops [:reduce_max, :reduce_min] - defp grad(op, [x, opts], ans, g) when op in @reduce_min_max_ops do + defp grad(op, [x, opts], ans, g, _batch_count) when op in @reduce_min_max_ops do g = reduce_g(x, opts, g) axes = opts[:axes] || Nx.axes(x) @@ -604,7 +604,7 @@ defmodule Nx.Defn.Grad do [{x, Nx.divide(num, den)}] end - defp grad(:dot, [x, axes_x, x_batch_axes, y, axes_y, y_batch_axes], ans, g) do + defp grad(:dot, [x, axes_x, x_batch_axes, y, axes_y, y_batch_axes], ans, g, _batch_count) do g = Nx.broadcast(g, ans) batch_gx = up_to(0, length(x_batch_axes)) @@ -632,13 +632,14 @@ defmodule Nx.Defn.Grad do [{x, gx}, {y, gy}] end - defp grad(:conv, [x, y, opts], ans, g) do + defp 
grad(:conv, [x, y, opts], ans, g, _batch_count) do grad_conv(x, y, opts, ans, g) end @window_chooser_op [:window_min, :window_max] - defp grad(op, [x, window_dimensions, opts], _ans, g) when op in @window_chooser_op do + defp grad(op, [x, window_dimensions, opts], _ans, g, _batch_count) + when op in @window_chooser_op do padding = opts[:padding] strides = opts[:strides] @@ -651,7 +652,7 @@ defmodule Nx.Defn.Grad do [{x, g}] end - defp grad(:window_sum, [x, window_dimensions, opts], _, g) do + defp grad(:window_sum, [x, window_dimensions, opts], _, g, _batch_count) do strides = opts[:strides] window_dilation = opts[:window_dilations] base_dilation = List.duplicate(1, Nx.rank(x)) @@ -687,7 +688,7 @@ defmodule Nx.Defn.Grad do [{x, g}] end - defp grad(:stack, [tensors, axis], ans, g) do + defp grad(:stack, [tensors, axis], ans, g, _batch_count) do zero_axes = List.duplicate(0, Nx.rank(ans)) ans_shape_list = Tuple.to_list(ans.shape) @@ -704,7 +705,7 @@ defmodule Nx.Defn.Grad do pairs end - defp grad(:concatenate, [tensors, axis], ans, g) do + defp grad(:concatenate, [tensors, axis], ans, g, _batch_count) do zero_axes = List.duplicate(0, Nx.rank(ans)) ans_shape_list = Tuple.to_list(ans.shape) @@ -720,7 +721,7 @@ defmodule Nx.Defn.Grad do pairs end - defp grad(:sort, [t, opts], _ans, g) do + defp grad(:sort, [t, opts], _ans, g, _batch_count) do idx = Nx.argsort(t, opts) reverse_idx = Nx.argsort(idx, axis: opts[:axis], direction: :asc) take_along_opts = Keyword.take(opts, [:axis]) @@ -728,7 +729,7 @@ defmodule Nx.Defn.Grad do [{t, g}] end - defp grad(:gather, [t, i, opts], _ans, g) do + defp grad(:gather, [t, i, opts], _ans, g, _batch_count) do i_axes = opts[:axes] i_shape = i.shape t_shape = t.shape @@ -748,46 +749,49 @@ defmodule Nx.Defn.Grad do [{t, g}] end - defp grad(:add, [x, y], ans, g) do + defp grad(:add, [x, y], ans, g, batch_count) do if x.data.id == y.data.id do [{x, Nx.multiply(g, 2.0)}] else - [unbroadcast(x, g, ans), unbroadcast(y, g, ans)] + [unbroadcast(x, g, ans, batch_count), unbroadcast(y, g, ans, batch_count)] end end - defp grad(:subtract, [x, y], ans, g) do - [unbroadcast(x, g, ans), unbroadcast(y, Nx.negate(g), ans)] + defp grad(:subtract, [x, y], ans, g, batch_count) do + [unbroadcast(x, g, ans, batch_count), unbroadcast(y, Nx.negate(g), ans, batch_count)] end - defp grad(:multiply, [x, y], ans, g) do + defp grad(:multiply, [x, y], ans, g, batch_count) do if x.data.id == y.data.id do [{x, Nx.multiply(g, Nx.multiply(2.0, x))}] else - [unbroadcast(x, Nx.multiply(g, y), ans), unbroadcast(y, Nx.multiply(g, x), ans)] + [ + unbroadcast(x, Nx.multiply(g, y), ans, batch_count), + unbroadcast(y, Nx.multiply(g, x), ans, batch_count) + ] end end - defp grad(:divide, [x, y], ans, g) do + defp grad(:divide, [x, y], ans, g, batch_count) do [ - unbroadcast(x, Nx.divide(g, y), ans), - unbroadcast(y, Nx.multiply(g, Nx.negate(Nx.divide(ans, y))), ans) + unbroadcast(x, Nx.divide(g, y), ans, batch_count), + unbroadcast(y, Nx.multiply(g, Nx.negate(Nx.divide(ans, y))), ans, batch_count) ] end - defp grad(:remainder, [x, y], ans, g) do + defp grad(:remainder, [x, y], ans, g, batch_count) do [ - unbroadcast(x, g, ans), - unbroadcast(y, Nx.multiply(g, Nx.negate(Nx.floor(Nx.divide(x, y)))), ans) + unbroadcast(x, g, ans, batch_count), + unbroadcast(y, Nx.multiply(g, Nx.negate(Nx.floor(Nx.divide(x, y)))), ans, batch_count) ] end - defp grad(:pow, [x, y], ans, g) do + defp grad(:pow, [x, y], ans, g, batch_count) do case y do %T{data: %Expr{op: :constant, args: [y]}} -> exponent = if y == 0.0, do: 1.0, 
else: y - 1.0 gx = Nx.multiply(y, Nx.pow(x, exponent)) - [unbroadcast(x, Nx.multiply(g, gx), ans)] + [unbroadcast(x, Nx.multiply(g, gx), ans, batch_count)] %{} -> exponent = Nx.select(Nx.equal(y, 0.0), 1.0, Nx.subtract(y, 1.0)) @@ -795,20 +799,24 @@ defmodule Nx.Defn.Grad do gx = Nx.multiply(y, Nx.pow(x, exponent)) gy = Nx.multiply(Nx.log(base), ans) - [unbroadcast(x, Nx.multiply(g, gx), ans), unbroadcast(y, Nx.multiply(g, gy), ans)] + + [ + unbroadcast(x, Nx.multiply(g, gx), ans, batch_count), + unbroadcast(y, Nx.multiply(g, gy), ans, batch_count) + ] end end - defp grad(:atan2, [x, y], ans, g) do + defp grad(:atan2, [x, y], ans, g, batch_count) do den = Nx.add(Nx.multiply(x, x), Nx.multiply(y, y)) [ - unbroadcast(x, Nx.multiply(g, Nx.divide(y, den)), ans), - unbroadcast(y, Nx.multiply(g, Nx.negate(Nx.divide(x, den))), ans) + unbroadcast(x, Nx.multiply(g, Nx.divide(y, den)), ans, batch_count), + unbroadcast(y, Nx.multiply(g, Nx.negate(Nx.divide(x, den))), ans, batch_count) ] end - defp grad(op, [x, y], ans, g) when op in [:min, :max] do + defp grad(op, [x, y], ans, g, batch_count) when op in [:min, :max] do lhs = Nx.divide( Nx.select(Nx.equal(x, ans), 1.0, 0.0), @@ -821,10 +829,13 @@ defmodule Nx.Defn.Grad do Nx.select(Nx.equal(x, ans), 2.0, 1.0) ) - [unbroadcast(x, Nx.multiply(g, lhs), ans), unbroadcast(y, Nx.multiply(g, rhs), ans)] + [ + unbroadcast(x, Nx.multiply(g, lhs), ans, batch_count), + unbroadcast(y, Nx.multiply(g, rhs), ans, batch_count) + ] end - defp grad(:as_type, [%{type: {:c, _}} = x], %{type: {output_type, _}}, g) + defp grad(:as_type, [%{type: {:c, _}} = x], %{type: {output_type, _}}, g, _batch_count) when output_type != :c do # For downcasting complex to float or integer types, `as_type/2` # behaves as: `x |> real() |> as_type(output_type)` @@ -836,15 +847,15 @@ defmodule Nx.Defn.Grad do [{x, Nx.real(g)}] end - defp grad(:as_type, [x], _ans, g) do + defp grad(:as_type, [x], _ans, g, _batch_count) do [{x, g}] end - defp grad(:bitcast, [x], _ans, g) do + defp grad(:bitcast, [x], _ans, g, _batch_count) do [{x, g}] end - defp grad(:abs, [%{type: {:c, _}} = z], ans, g) do + defp grad(:abs, [%{type: {:c, _}} = z], ans, g, _batch_count) do # For the complex variant of abs(z), we can define the forward-mode # derivative abs'(z) as follows (for an element-wise function): # abs(z)^2 = z.z* @@ -875,35 +886,35 @@ defmodule Nx.Defn.Grad do [{z, dz}] end - defp grad(:abs, [x], _ans, g) do + defp grad(:abs, [x], _ans, g, _batch_count) do [{x, Nx.select(Nx.greater_equal(x, 0.0), g, Nx.negate(g))}] end - defp grad(:sqrt, [x], ans, g) do + defp grad(:sqrt, [x], ans, g, _batch_count) do [{x, Nx.divide(Nx.multiply(g, 0.5), ans)}] end - defp grad(:cbrt, [x], ans, g) do + defp grad(:cbrt, [x], ans, g, _batch_count) do [{x, Nx.divide(g, 3 |> Nx.multiply(ans) |> Nx.multiply(ans))}] end - defp grad(:exp, [x], ans, g) do + defp grad(:exp, [x], ans, g, _batch_count) do [{x, Nx.multiply(g, ans)}] end - defp grad(:expm1, [x], ans, g) do + defp grad(:expm1, [x], ans, g, _batch_count) do [{x, Nx.multiply(g, Nx.add(ans, 1))}] end - defp grad(:log, [x], _ans, g) do + defp grad(:log, [x], _ans, g, _batch_count) do [{x, Nx.divide(g, x)}] end - defp grad(:log1p, [x], _ans, g) do + defp grad(:log1p, [x], _ans, g, _batch_count) do [{x, Nx.divide(g, Nx.add(x, 1))}] end - defp grad(:sigmoid, [x], ans, g) do + defp grad(:sigmoid, [x], ans, g, _batch_count) do gs = x |> Nx.negate() @@ -914,67 +925,67 @@ defmodule Nx.Defn.Grad do [{x, Nx.multiply(g, gs)}] end - defp grad(:negate, [x], _ans, g) do + defp 
grad(:negate, [x], _ans, g, _batch_count) do [{x, Nx.negate(g)}] end - defp grad(:rsqrt, [x], _ans, g) do + defp grad(:rsqrt, [x], _ans, g, _batch_count) do [{x, Nx.multiply(Nx.multiply(g, -0.5), Nx.pow(x, -1.5))}] end - defp grad(:sin, [x], _ans, g) do + defp grad(:sin, [x], _ans, g, _batch_count) do [{x, Nx.multiply(g, Nx.cos(x))}] end - defp grad(:asin, [x], _ans, g) do + defp grad(:asin, [x], _ans, g, _batch_count) do [{x, Nx.multiply(g, Nx.rsqrt(Nx.subtract(1.0, Nx.multiply(x, x))))}] end - defp grad(:sinh, [x], _ans, g) do + defp grad(:sinh, [x], _ans, g, _batch_count) do [{x, Nx.multiply(g, Nx.cosh(x))}] end - defp grad(:asinh, [x], _ans, g) do + defp grad(:asinh, [x], _ans, g, _batch_count) do [{x, Nx.multiply(g, Nx.rsqrt(Nx.add(Nx.multiply(x, x), 1.0)))}] end - defp grad(:acosh, [x], _ans, g) do + defp grad(:acosh, [x], _ans, g, _batch_count) do [{x, Nx.multiply(g, Nx.rsqrt(Nx.subtract(Nx.multiply(x, x), 1.0)))}] end - defp grad(:atanh, [x], _ans, g) do + defp grad(:atanh, [x], _ans, g, _batch_count) do [{x, Nx.divide(g, Nx.subtract(1.0, Nx.multiply(x, x)))}] end - defp grad(:cos, [x], _ans, g) do + defp grad(:cos, [x], _ans, g, _batch_count) do [{x, Nx.multiply(g, Nx.negate(Nx.sin(x)))}] end - defp grad(:acos, [x], _ans, g) do + defp grad(:acos, [x], _ans, g, _batch_count) do [{x, Nx.multiply(g, Nx.negate(Nx.rsqrt(Nx.subtract(1.0, Nx.multiply(x, x)))))}] end - defp grad(:cosh, [x], _ans, g) do + defp grad(:cosh, [x], _ans, g, _batch_count) do [{x, Nx.multiply(g, Nx.sinh(x))}] end - defp grad(:tan, [x], _ans, g) do + defp grad(:tan, [x], _ans, g, _batch_count) do cos = Nx.cos(x) [{x, g |> Nx.divide(cos) |> Nx.divide(cos)}] end - defp grad(:atan, [x], _ans, g) do + defp grad(:atan, [x], _ans, g, _batch_count) do [{x, Nx.divide(g, Nx.add(1.0, Nx.multiply(x, x)))}] end - defp grad(:tanh, [x], ans, g) do + defp grad(:tanh, [x], ans, g, _batch_count) do [{x, Nx.multiply(g, Nx.subtract(1.0, Nx.multiply(ans, ans)))}] end @half_sqrt_pi :math.sqrt(:math.pi()) / 2 @two_rsqrt_pi 2 / :math.sqrt(:math.pi()) - defp grad(:erf, [x], _ans, g) do + defp grad(:erf, [x], _ans, g, _batch_count) do gs = x |> Nx.multiply(x) @@ -985,7 +996,7 @@ defmodule Nx.Defn.Grad do [{x, Nx.multiply(g, gs)}] end - defp grad(:erfc, [x], _ans, g) do + defp grad(:erfc, [x], _ans, g, _batch_count) do gs = x |> Nx.multiply(x) @@ -996,16 +1007,16 @@ defmodule Nx.Defn.Grad do [{x, Nx.multiply(g, gs)}] end - defp grad(:erf_inv, [x], ans, g) do + defp grad(:erf_inv, [x], ans, g, _batch_count) do gs = Nx.multiply(@half_sqrt_pi, Nx.exp(Nx.multiply(ans, ans))) [{x, Nx.multiply(g, gs)}] end - defp grad(:attach_token, [_, x], _ans, g) do + defp grad(:attach_token, [_, x], _ans, g, _batch_count) do [{x, g}] end - defp grad(:conjugate, [%{type: {type, _}} = t], _ans, g) do + defp grad(:conjugate, [%{type: {type, _}} = t], _ans, g, _batch_count) do if type == :c do [{t, Nx.conjugate(g)}] else @@ -1013,22 +1024,22 @@ defmodule Nx.Defn.Grad do end end - defp grad(:real, [t], _ans, g) do + defp grad(:real, [t], _ans, g, _batch_count) do # real(z) = (z + conj(z))/2 # real'(z) = (z' + (conj(z))')/2 = (z' + conj(z'))/2 = real(z') [{t, Nx.real(g)}] end - defp grad(:imag, [t], _ans, g) do + defp grad(:imag, [t], _ans, g, _batch_count) do # imag(z) = (z - z*) / 2i # imag'(z) = z' - z'* / 2i = imag(z') [{t, Nx.imag(g)}] end - defp grad(:fft, args, ans, g), do: grad_fft(:fft, args, ans, g) - defp grad(:ifft, args, ans, g), do: grad_fft(:ifft, args, ans, g) + defp grad(:fft, args, ans, g, _batch_count), do: grad_fft(:fft, args, ans, g) + defp 
grad(:ifft, args, ans, g, _batch_count), do: grad_fft(:ifft, args, ans, g) - defp grad(:triangular_solve, [a_input, b, opts], x_input, g) do + defp grad(:triangular_solve, [a_input, b, opts], x_input, g, _batch_count) do # We can model the triangular solve function as X = triangular_solve(a, b) # where the function itself depends on the options passed. @@ -1111,7 +1122,7 @@ defmodule Nx.Defn.Grad do [{a_input, da}, {b, db}] end - defp grad(op, [tensor, source, init_value, window_dimensions, opts], _ans, g) + defp grad(op, [tensor, source, init_value, window_dimensions, opts], _ans, g, _batch_count) when op in [:window_scatter_max, :window_scatter_min] do padding_config = opts[:padding] strides = opts[:strides] @@ -1152,7 +1163,7 @@ defmodule Nx.Defn.Grad do [{tensor, dtensor}, {source, dsource}, {init_value, dinit_value}] end - defp grad(:quotient, _, _, _) do + defp grad(:quotient, _, _, _, _batch_count) do raise ArgumentError, """ cannot compute gradient for Nx.quotient/2. @@ -1162,7 +1173,7 @@ defmodule Nx.Defn.Grad do """ end - defp grad(:reduce, _, _, _) do + defp grad(:reduce, _, _, _, _batch_count) do raise ArgumentError, """ cannot compute gradient for Nx.reduce/4. @@ -1174,7 +1185,7 @@ defmodule Nx.Defn.Grad do """ end - defp grad(:window_reduce, _, _, _) do + defp grad(:window_reduce, _, _, _, _batch_count) do raise ArgumentError, """ cannot compute gradient for Nx.window_reduce/5. @@ -1188,7 +1199,7 @@ defmodule Nx.Defn.Grad do @error [:map, :window_product] - defp grad(op, args, _, _) when op in @error do + defp grad(op, args, _, _, _batch_count) when op in @error do raise ArgumentError, """ cannot compute gradient for Nx.#{op}/#{length(args)}. @@ -1198,7 +1209,7 @@ defmodule Nx.Defn.Grad do """ end - defp grad(op, args, _, _) do + defp grad(op, args, _, _, _batch_count) do raise ArgumentError, """ gradient not yet implemented for Nx.#{op}/#{length(args)}. @@ -1398,14 +1409,31 @@ defmodule Nx.Defn.Grad do ## General helpers - defp unbroadcast(%{shape: shape} = x, res, %{shape: shape}), do: {x, res} + defp unbroadcast(%{shape: shape} = x, res, %{shape: shape}, _batch_count), do: {x, res} + + defp unbroadcast(x, res, %{shape: new_shape}, batch_count) do + # Preserve batch dims when x doesn't already have them: + # x has fewer dims than batch_count, or its leading dims are all 1. 
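+    # For example, with batch_count = 1: a scalar x ({}) or an x of shape {1, 3}
+    # gets batch_offset = 1, so the leading batch dim is preserved, while an x of
+    # shape {2, 3} already carries its own leading dim and gets batch_offset = 0.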
+ batch_offset = + cond do + batch_count == 0 -> + 0 + + tuple_size(x.shape) < batch_count -> + batch_count + + Enum.all?(0..(batch_count - 1)//1, &(elem(x.shape, &1) == 1)) -> + batch_count + + true -> + 0 + end - defp unbroadcast(%{shape: shape} = x, res, %{shape: new_shape}) do - axes = Nx.Shape.broadcast_axes(shape, new_shape) - {x, grad_broadcast(x, new_shape, axes, res)} + axes = Nx.Shape.broadcast_axes(x.shape, new_shape) + {x, grad_broadcast(x, new_shape, axes, res, batch_offset)} end - defp grad_broadcast(x, shape, axes, g) do + defp grad_broadcast(x, shape, axes, g, batch_offset \\ 0) do implicit_axes = for {a, i} <- Enum.with_index(axes), elem(shape, a) != 1 and elem(x.shape, i) == 1, @@ -1414,6 +1442,10 @@ defmodule Nx.Defn.Grad do {implicit_axes, broadcast_axes} = Enum.unzip(implicit_axes) explicit_axes = Nx.axes(shape) -- axes + # Skip batch dims — they should be preserved, not summed + implicit_axes = Enum.filter(implicit_axes, &(&1 >= batch_offset)) + explicit_axes = Enum.filter(explicit_axes, &(&1 >= batch_offset)) + g = case explicit_axes ++ implicit_axes do [] -> g @@ -1421,8 +1453,19 @@ defmodule Nx.Defn.Grad do end case broadcast_axes do - [] -> g - _ -> Nx.broadcast(g, x.shape, axes: Nx.axes(x.shape) -- broadcast_axes) + [] -> + g + + _ when batch_offset > 0 -> + # g has batch dims that x.shape doesn't — broadcast inner dims only + batch_dims = g.shape |> Tuple.to_list() |> Enum.take(batch_offset) + inner_target = List.to_tuple(batch_dims ++ Tuple.to_list(x.shape)) + inner_axes = Enum.map(Nx.axes(x.shape), &(&1 + batch_offset)) + keep_axes = inner_axes -- Enum.map(broadcast_axes, &(&1 + batch_offset)) + Nx.broadcast(g, inner_target, axes: Enum.to_list(0..(batch_offset - 1)) ++ keep_axes) + + _ -> + Nx.broadcast(g, x.shape, axes: Nx.axes(x.shape) -- broadcast_axes) end end diff --git a/nx/lib/nx/lin_alg/cholesky.ex b/nx/lib/nx/lin_alg/cholesky.ex index 7f6719c484..4ab76b31e5 100644 --- a/nx/lib/nx/lin_alg/cholesky.ex +++ b/nx/lib/nx/lin_alg/cholesky.ex @@ -76,8 +76,8 @@ defmodule Nx.LinAlg.Cholesky do end defn cholesky_grad(l, _input, g) do - num = g |> Nx.tril() |> Nx.dot([0], l, [0]) |> Nx.transpose() - den = l |> Nx.shape() |> Nx.eye() |> Nx.add(1) + num = g |> Nx.tril() |> Nx.dot([-2], l, [-2]) |> batch_transpose() + den = Nx.eye(l) |> Nx.add(1) phi_tril = num |> Nx.divide(den) |> Nx.tril() bm = Nx.LinAlg.triangular_solve(l, phi_tril, transform_a: :transpose) @@ -97,6 +97,17 @@ defmodule Nx.LinAlg.Cholesky do [dl] end + deftransformp batch_transpose(t) do + rank = tuple_size(t.shape) + + if rank <= 2 do + Nx.transpose(t) + else + axes = Enum.to_list(0..(rank - 3)) ++ [rank - 1, rank - 2] + Nx.transpose(t, axes: axes) + end + end + defnp conjugate_if_complex(x) do case Nx.type(x) do {:c, _} -> diff --git a/nx/lib/nx/lin_alg/qr.ex b/nx/lib/nx/lin_alg/qr.ex index 94f49937db..e1e465eb75 100644 --- a/nx/lib/nx/lin_alg/qr.ex +++ b/nx/lib/nx/lin_alg/qr.ex @@ -148,14 +148,23 @@ defmodule Nx.LinAlg.QR do # Definition taken from https://arxiv.org/pdf/2009.10071.pdf # Equation (3) r_inv = Nx.LinAlg.invert(r) + ba = batch_axes(r) - m = Nx.dot(r, Nx.LinAlg.adjoint(dr)) |> Nx.subtract(Nx.dot(Nx.LinAlg.adjoint(dq), q)) + m = + Nx.dot(r, [-1], ba, Nx.LinAlg.adjoint(dr), [-2], ba) + |> Nx.subtract(Nx.dot(Nx.LinAlg.adjoint(dq), [-1], ba, q, [-2], ba)) # copyltu m_ltu = Nx.tril(m) |> Nx.add(m |> Nx.tril(k: -1) |> Nx.LinAlg.adjoint()) - da = dq |> Nx.add(Nx.dot(q, m_ltu)) |> Nx.dot(Nx.LinAlg.adjoint(r_inv)) + q_m = Nx.dot(q, [-1], ba, m_ltu, [-2], ba) + da = Nx.dot(Nx.add(dq, q_m), 
[-1], ba, Nx.LinAlg.adjoint(r_inv), [-2], ba) [da] end + + deftransformp batch_axes(t) do + rank = tuple_size(t.shape) + Enum.to_list(0..(rank - 3)//1) + end end diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs index df3efbcb89..e9b719f039 100644 --- a/nx/test/nx/defn/grad_test.exs +++ b/nx/test/nx/defn/grad_test.exs @@ -4872,7 +4872,83 @@ defmodule Nx.Defn.GradTest do end end + # Module-level defn functions for vectorization tests + defn while_square_n(x) do + {_i, result} = + while {i = 0, x}, Nx.less(i, 3) do + {i + 1, Nx.multiply(x, x)} + end + + Nx.sum(result) + end + describe "vectorization" do + @vec_atol 1.0e-4 + + # Compares vectorized grad against per-element grads. + # + # Accepts either a plain tensor (vectorized internally with `:batch`) + # or a pre-vectorized tensor with any number of vec axes. Per-element + # grads are computed on devectorized slices and compared to the + # corresponding slice of the vectorized grad's devectorized form. + defp check_vectorized_grad(x_or_vec, fun, opts \\ []) do + atol = opts[:atol] || @vec_atol + + x_vec = + if x_or_vec.vectorized_axes == [] do + Nx.vectorize(x_or_vec, :batch) + else + x_or_vec + end + + vec_axes = x_vec.vectorized_axes + n_vec = length(vec_axes) + vec_dims = Enum.map(vec_axes, fn {_name, size} -> size end) + + x_devec = Nx.devectorize(x_vec, keep_names: false) + vec_grad = Nx.Defn.grad(x_vec, fun) + vec_grad_devec = Nx.devectorize(vec_grad, keep_names: false) + + inner_shape = + x_devec.shape + |> Tuple.to_list() + |> Enum.drop(n_vec) + |> List.to_tuple() + + ranges = Enum.map(vec_dims, &Enum.to_list(0..(&1 - 1))) + indices = cartesian_product(ranges) + + for idx <- indices do + x_elem = + Enum.reduce(idx, x_devec, fn i, acc -> acc[i] end) + |> Nx.reshape(inner_shape) + + vec_elem = + Enum.reduce(idx, vec_grad_devec, fn i, acc -> acc[i] end) + |> Nx.reshape(inner_shape) + + elem_grad = Nx.Defn.grad(x_elem, fun) + + for {v, e} <- Enum.zip(Nx.to_flat_list(vec_elem), Nx.to_flat_list(elem_grad)) do + if v == :nan and e == :nan do + :ok + else + assert_in_delta v, e, atol, "Mismatch at idx #{inspect(idx)}: vec=#{v}, elem=#{e}" + end + end + end + + :ok + end + + defp cartesian_product([]), do: [[]] + + defp cartesian_product([list | rest]) do + for x <- list, suffix <- cartesian_product(rest), do: [x | suffix] + end + + # ── Pre-existing edge case tests ───────────────────────────────── + test "supports combination of vectorized and non-vectorized tensors" do x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x) y = 1 @@ -4893,89 +4969,504 @@ defmodule Nx.Defn.GradTest do assert grad == Nx.cos(x) end - # Skipping this as it's not supported yet. 
- @tag :skip - test "edge case where the same name changes meaning" do - x = Nx.tensor([[1], [2], [3]]) |> Nx.vectorize(x: 3) + test "supports heterogenous vectorization combinations" do + x_vec = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |> Nx.vectorize(:x) + y_vec = Nx.tensor([10.0, 20.0]) |> Nx.vectorize(:y) + + {grad_x, grad_y} = + Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end) + + union = [x: 2, y: 2] + assert grad_x.vectorized_axes == union + assert grad_y.vectorized_axes == union + end + + test "supports same-axis vectorization combinations" do + x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) + y = Nx.tensor([10, 20]) + x_vec = Nx.vectorize(x, :x) + y_vec = Nx.vectorize(y, :x) + + {grad_x_vec, grad_y_vec} = + Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end) + + assert grad_x_vec == + Nx.tensor([[10.0, 10.0, 10.0], [20.0, 20.0, 20.0]]) + |> Nx.vectorize(x_vec.vectorized_axes) + + assert grad_y_vec == Nx.tensor([6.0, 15.0]) |> Nx.vectorize(y_vec.vectorized_axes) + end + + # ── Edge case tests ────────────────────────────────────────────── + + test "vectorize/devectorize inside grad function" do + x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |> Nx.vectorize(:batch) grad = Nx.Defn.grad(x, fn t -> - devec = Nx.devectorize(t, keep_names: true) - new_axis = Nx.reshape(devec, {1, 3, 1}, names: [:x, nil, nil]) + devec = Nx.devectorize(t, keep_names: false) + re_vec = Nx.vectorize(devec, :batch) + Nx.sum(re_vec) + end) + + assert grad.vectorized_axes == [batch: 2] + expected = Nx.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) |> Nx.vectorize(:batch) + assert grad == expected + end + + test "reshape then vectorize inside grad" do + x = Nx.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]) |> Nx.vectorize(:batch) - Nx.vectorize(new_axis, x: 1) + grad = + Nx.Defn.grad(x, fn t -> + reshaped = Nx.reshape(t, {2, 2}) + Nx.sum(reshaped) end) - assert grad == Nx.tensor([[1], [1], [1]]) |> Nx.vectorize(x: 3) + assert grad.vectorized_axes == [batch: 2] end - test "supports heterogenous vectorization combinations" do - x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) - y = Nx.tensor([10, 20]) + test "non-vectorized input, vectorized output" do + x_vec = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:batch) + y = Nx.tensor([1.0, 1.0]) - # first case: y is vectorized scalar, x is vectorized vectors, different vectorized axis names - # expected result: equivalent to fully broadcasting one tensor onto the other - x_vec = Nx.vectorize(x, :x) - y_vec = Nx.vectorize(y, :y) + grad = Nx.Defn.grad(y, fn y -> Nx.sum(Nx.add(x_vec, y)) end) + assert grad.vectorized_axes == [batch: 2] + end - grad_fun = fn x, y -> - Nx.Defn.grad({x, y}, fn {a, b} -> Nx.multiply(a, b) end) - end + test "rename vectorized axes inside grad" do + x = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:a) - {grad_x_vec, grad_y_vec} = grad_fun.(x_vec, y_vec) + grad = + Nx.Defn.grad(x, fn t -> + devec = Nx.devectorize(t, keep_names: false) + Nx.vectorize(devec, :b) |> Nx.sum() + end) - # Explicit assertion on the results - assert grad_x_vec == - Nx.tensor([ - [ - [10.0, 10.0, 10.0], - [20.0, 20.0, 20.0] - ], - [ - [10.0, 10.0, 10.0], - [20.0, 20.0, 20.0] - ] - ]) - |> Nx.vectorize([:x, :y]) + assert grad.vectorized_axes == [a: 2] + end - assert grad_y_vec == - Nx.tensor([ - [6.0, 6.0], - [15.0, 15.0] - ]) - |> Nx.vectorize([:x, :y]) + test "devectorize then compute then return scalar" do + x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |> Nx.vectorize(:batch) - # Conceptual assertion: the result should be 
equivalent to calling Nx.Defn.grad with - # each cross-entry of the combined vectors [(x0, y0), (x0, y1), (x1, y0), (x1, y1)] + grad = + Nx.Defn.grad(x, fn t -> + devec = Nx.devectorize(t, keep_names: false) + Nx.sum(Nx.multiply(devec, devec)) + end) - {x0y0_wrt_x, x0y0_wrt_y} = grad_fun.(x[0], y[0]) - {x0y1_wrt_x, x0y1_wrt_y} = grad_fun.(x[0], y[1]) - {x1y0_wrt_x, x1y0_wrt_y} = grad_fun.(x[1], y[0]) - {x1y1_wrt_x, x1y1_wrt_y} = grad_fun.(x[1], y[1]) + assert grad.vectorized_axes == [batch: 2] + end - assert grad_x_vec == - [x0y0_wrt_x, x0y1_wrt_x, x1y0_wrt_x, x1y1_wrt_x] - |> Nx.stack() - |> Nx.reshape({2, 2, 3}) - |> Nx.vectorize([:x, :y]) - - assert grad_y_vec == - [x0y0_wrt_y, x0y1_wrt_y, x1y0_wrt_y, x1y1_wrt_y] - |> Nx.stack() - |> Nx.reshape({2, 2}) - |> Nx.vectorize([:x, :y]) - - # second case: y is vectorized scalar, x is vectorized vectors, same vectorized axis name - # expected result: equivalent to "row-wise" broadcasting - x_vec = Nx.vectorize(x, :x) - y_vec = Nx.vectorize(y, :x) - {grad_x_vec, grad_y_vec} = Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end) + test "chained devectorize/vectorize with computation" do + x = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:batch) - assert grad_x_vec == - Nx.tensor([[10.0, 10.0, 10.0], [20.0, 20.0, 20.0]]) - |> Nx.vectorize(x_vec.vectorized_axes) + grad = + Nx.Defn.grad(x, fn t -> + devec = Nx.devectorize(t, keep_names: false) + doubled = Nx.multiply(devec, 2) + re_vec = Nx.vectorize(doubled, :batch) + Nx.sum(re_vec) + end) - assert grad_y_vec == Nx.tensor([6.0, 15.0]) |> Nx.vectorize(y_vec.vectorized_axes) + assert grad.vectorized_axes == [batch: 2] + expected = Nx.tensor([[2.0, 2.0], [2.0, 2.0]]) |> Nx.vectorize(:batch) + assert grad == expected + end + + test "multiple vectorized axes input" do + x = + Nx.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]) + |> Nx.vectorize(:a) + |> Nx.vectorize(:b) + + grad = Nx.Defn.grad(x, fn t -> Nx.sum(Nx.multiply(t, t)) end) + assert grad.vectorized_axes == [a: 2, b: 2] + end + + test "second-order grad" do + x = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:batch) + + grad = + Nx.Defn.grad(x, fn t -> + inner = Nx.Defn.grad(t, fn u -> Nx.sum(Nx.pow(u, 3)) end) + Nx.sum(inner) + end) + + assert grad.vectorized_axes == [batch: 2] + end + + test "constant-grad ops with vectorized inputs (all/any/argmax/argmin)" do + x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |> Nx.vectorize(:batch) + + grad = Nx.Defn.grad(x, fn t -> Nx.sum(Nx.multiply(t, Nx.argmax(t))) end) + assert grad.vectorized_axes == [batch: 2] + + grad = Nx.Defn.grad(x, fn t -> Nx.sum(Nx.multiply(t, Nx.argmin(t))) end) + assert grad.vectorized_axes == [batch: 2] + end + + test "window_scatter_max/min with vectorized inputs" do + t_data = + Nx.tensor([ + [[7.0, 2.0, 5.0, 3.0], [8.0, 9.0, 1.0, 5.0]], + [[1.0, 5.0, 7.0, 0.0], [6.0, 2.0, 4.0, 3.0]] + ]) + + t = Nx.vectorize(t_data, :batch) + source = Nx.tensor([[2.0, 6.0], [3.0, 1.0]]) + init = 0 + + grad = + Nx.Defn.grad(t, fn t -> + Nx.window_scatter_max(t, source, init, {1, 2}, strides: [1, 2], padding: :valid) + end) + + assert grad.vectorized_axes == [batch: 2] + end + + test "value_and_grad with vectorized target" do + x = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:batch) + + {value, grad} = + Nx.Defn.value_and_grad(x, fn t -> Nx.sum(Nx.multiply(t, t)) end) + + assert value.vectorized_axes == [batch: 2] + assert grad.vectorized_axes == [batch: 2] + end + + test "grad of non-vectorized target with vectorized capture" do + x_vec = 
Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:batch) + y = Nx.tensor(2.0) + + grad = Nx.Defn.grad(y, fn y -> Nx.sum(Nx.multiply(x_vec, y)) end) + assert grad.vectorized_axes == [batch: 2] + end + + test "grad w.r.t. tuple of vectorized and non-vectorized" do + x = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:batch) + y = Nx.tensor(2.0) + + {grad_x, grad_y} = + Nx.Defn.grad({x, y}, fn {a, b} -> Nx.sum(Nx.multiply(a, b)) end) + + assert grad_x.vectorized_axes == [batch: 2] + assert grad_y.vectorized_axes == [batch: 2] + end + + test "large vectorized batch" do + x = Nx.iota({64, 4}, type: :f32) |> Nx.divide(256) |> Nx.add(0.1) + + check_vectorized_grad(x, fn x -> Nx.sum(Nx.exp(x)) end) + end + + test "mixed vectorized axes with add" do + x_vec = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:x) + y_vec = Nx.tensor([10.0, 20.0]) |> Nx.vectorize(:y) + {grad_x, grad_y} = Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.add(a, b) end) + union = [x: 2, y: 2] + assert grad_x.vectorized_axes == union + assert grad_y.vectorized_axes == union + end + + test "mixed vectorized axes with multiply" do + x_vec = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:x) + y_vec = Nx.tensor([10.0, 20.0]) |> Nx.vectorize(:y) + {grad_x, grad_y} = Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end) + union = [x: 2, y: 2] + assert grad_x.vectorized_axes == union + assert grad_y.vectorized_axes == union + end + + test "heterogenous vec grad: multiply (per-instance partials)" do + x_vec = Nx.tensor([2.0, 5.0]) |> Nx.vectorize(:foo) + y_vec = Nx.tensor([3.0, 4.0]) |> Nx.vectorize(:bar) + + {grad_x, grad_y} = + Nx.Defn.grad({x_vec, y_vec}, fn {x, y} -> + Nx.sum(Nx.multiply(x, y)) + end) + + assert grad_x == + Nx.tensor([[3.0, 4.0], [3.0, 4.0]]) |> Nx.vectorize(foo: 2, bar: 2) + + assert grad_y == + Nx.tensor([[2.0, 2.0], [5.0, 5.0]]) |> Nx.vectorize(foo: 2, bar: 2) + end + + test "heterogenous vec grad: add (constant per-instance partials)" do + x_vec = Nx.tensor([2.0, 5.0]) |> Nx.vectorize(:foo) + y_vec = Nx.tensor([3.0, 4.0, 6.0]) |> Nx.vectorize(:bar) + + {grad_x, grad_y} = + Nx.Defn.grad({x_vec, y_vec}, fn {x, y} -> + Nx.sum(Nx.add(x, y)) + end) + + ones = Nx.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) + assert grad_x == ones |> Nx.vectorize(foo: 2, bar: 3) + assert grad_y == ones |> Nx.vectorize(foo: 2, bar: 3) + end + + test "heterogenous vec grad: sin(add) (per-instance partials)" do + x_vec = Nx.tensor([0.5, 1.0]) |> Nx.vectorize(:foo) + y_vec = Nx.tensor([0.0, 0.3]) |> Nx.vectorize(:bar) + + {grad_x, grad_y} = + Nx.Defn.grad({x_vec, y_vec}, fn {x, y} -> + Nx.sum(Nx.sin(Nx.add(x, y))) + end) + + expected = + Nx.tensor([ + [:math.cos(0.5), :math.cos(0.8)], + [:math.cos(1.0), :math.cos(1.3)] + ]) + + assert grad_x.vectorized_axes == [foo: 2, bar: 2] + assert grad_y.vectorized_axes == [foo: 2, bar: 2] + + assert_all_close(Nx.devectorize(grad_x), expected, atol: 1.0e-6) + assert_all_close(Nx.devectorize(grad_y), expected, atol: 1.0e-6) + end + + test "heterogenous vec grad: x^2 * y (asymmetric per-instance partials)" do + x_vec = Nx.tensor([2.0, 5.0]) |> Nx.vectorize(:foo) + y_vec = Nx.tensor([3.0, 4.0]) |> Nx.vectorize(:bar) + + {grad_x, grad_y} = + Nx.Defn.grad({x_vec, y_vec}, fn {x, y} -> + Nx.sum(Nx.multiply(Nx.pow(x, 2), y)) + end) + + assert grad_x == + Nx.tensor([[12.0, 16.0], [30.0, 40.0]]) |> Nx.vectorize(foo: 2, bar: 2) + + assert grad_y == + Nx.tensor([[4.0, 4.0], [25.0, 25.0]]) |> Nx.vectorize(foo: 2, bar: 2) + end + + test "hidden vec axes inside grad: rename to different 
names and lengths" do + x = Nx.iota({4, 2, 3}, type: :f32) |> Nx.vectorize(:batch) + + grad = + Nx.Defn.grad(x, fn t -> + t + |> Nx.devectorize(keep_names: false) + |> Nx.vectorize(a: 4, b: 2) + |> Nx.sum() + end) + + devec = Nx.devectorize(grad, keep_names: false) + assert grad.vectorized_axes == [batch: 4] + assert devec == Nx.broadcast(1.0, devec) + end + + test "hidden vec axes inside grad: same outer name with hidden intermediate axes" do + x = Nx.iota({4, 2, 3}, type: :f32) |> Nx.vectorize(:batch) + + grad = + Nx.Defn.grad(x, fn t -> + t + |> Nx.devectorize(keep_names: false) + |> Nx.vectorize(a: 4, b: 2) + |> Nx.sum() + |> Nx.devectorize(keep_names: false) + |> Nx.vectorize(batch: 4) + |> Nx.sum() + end) + + devec = Nx.devectorize(grad, keep_names: false) + assert grad.vectorized_axes == [batch: 4] + assert devec == Nx.broadcast(1.0, devec) + end + + test "hidden vec axes inside grad: same length, different name" do + x = Nx.iota({4, 2, 3}, type: :f32) |> Nx.vectorize(:batch) + + grad = + Nx.Defn.grad(x, fn t -> + t + |> Nx.devectorize(keep_names: false) + |> Nx.vectorize(other: 4) + |> Nx.sum() + end) + + assert grad == + Nx.broadcast(1.0, {4, 2, 3}) |> Nx.vectorize(batch: 4) + end + + test "vectorized grad through Nx.LinAlg.qr (custom grad)" do + x = + Nx.tensor([ + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]] + ]) + |> Nx.vectorize(:batch) + + check_vectorized_grad(x, fn t -> + {q, r} = Nx.LinAlg.qr(t) + Nx.add(Nx.sum(q), Nx.sum(r)) + end) + end + + test "vectorized grad through Nx.LinAlg.eigh (autograd via defn)" do + x = + Nx.tensor([ + [[2.0, 1.0], [1.0, 3.0]], + [[4.0, 0.5], [0.5, 5.0]] + ]) + |> Nx.vectorize(:batch) + + check_vectorized_grad(x, fn t -> + {evals, evecs} = Nx.LinAlg.eigh(t) + Nx.add(Nx.sum(evals), Nx.sum(evecs)) + end) + end + + test "three different vectorized axes" do + x_vec = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:x) + y_vec = Nx.tensor([10.0, 20.0]) |> Nx.vectorize(:y) + z_vec = Nx.tensor([100.0]) |> Nx.vectorize(:z) + + {grad_x, grad_y, grad_z} = + Nx.Defn.grad({x_vec, y_vec, z_vec}, fn {a, b, c} -> + Nx.add(Nx.multiply(a, b), c) + end) + + union = [x: 2, y: 2, z: 1] + assert grad_x.vectorized_axes == union + assert grad_y.vectorized_axes == union + assert grad_z.vectorized_axes == union + end + + test "two vectorized inputs through sin(add)" do + x_vec = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:x) + y_vec = Nx.tensor([0.5, 1.0]) |> Nx.vectorize(:y) + + {grad_x, grad_y} = + Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.sin(Nx.add(a, b)) end) + + union = [x: 2, y: 2] + assert grad_x.vectorized_axes == union + assert grad_y.vectorized_axes == union + end + + # ── check_vectorized_grad tests (from vectorized_grad_test) ────── + + test "exp (basic vectorized grad)" do + x = Nx.tensor([[0.5, 1.0, 1.5], [2.0, 0.3, 0.8], [1.2, 0.7, 0.1]], type: :f32) + check_vectorized_grad(x, fn x -> Nx.sum(Nx.exp(x)) end) + end + + test "multiply then sum" do + x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], type: :f32) + check_vectorized_grad(x, fn x -> Nx.sum(Nx.multiply(x, x)) end) + end + + test "sum axis 0 on 2D inner (exercises reduce_g fix)" do + x = + Nx.tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]] + ], + type: :f32 + ) + + check_vectorized_grad(x, fn x -> Nx.sum(Nx.sum(x, axes: [0])) end) + end + + test "reshape inside grad (exercises reshape boundary crossing)" do + x = + Nx.tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 
8.0]], + [[9.0, 10.0], [11.0, 12.0]] + ], + type: :f32 + ) + + check_vectorized_grad(x, fn x -> Nx.sum(Nx.transpose(Nx.reshape(x, {4}))) end) + end + + test "concatenate (exercises concatenate grad axis offset)" do + x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], type: :f32) + + check_vectorized_grad(x, fn x -> + Nx.sum(Nx.concatenate([x, Nx.multiply(x, 2)], axis: 0)) + end) + end + + test "window_sum (exercises window_scatter adjust)" do + x = + Nx.tensor( + [ + [[1.0, 2.0, 3.0, 4.0]], + [[5.0, 6.0, 7.0, 8.0]], + [[-1.0, 0.0, 1.0, 2.0]] + ], + type: :f32 + ) + + check_vectorized_grad(x, fn x -> Nx.sum(Nx.window_sum(x, {1, 2})) end) + end + + test "dot with captured matrix (exercises concrete tensor handling)" do + w = Nx.tensor([[1.0, 2.0], [3.0, 4.0]], type: :f32) + + check_vectorized_grad( + Nx.tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]] + ], + type: :f32 + ), + fn x -> Nx.sum(Nx.dot(x, w)) end + ) + end + + test "cumulative_sum (exercises axis name collision fix)" do + x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], type: :f32) + check_vectorized_grad(x, fn x -> Nx.sum(Nx.cumulative_sum(x)) end) + end + + test "conv with vectorized (exercises conv grad)" do + k = Nx.tensor([[[1.0, 0.0, -1.0]]]) + x = Nx.tensor([[[[1.0, 2.0, 3.0, 4.0]]], [[[5.0, 6.0, 7.0, 8.0]]]], type: :f32) + check_vectorized_grad(x, fn x -> Nx.sum(Nx.conv(x, k)) end) + end + + test "while loop: repeated squaring (exercises while boundary)" do + x = Nx.tensor([[0.5, 0.8], [1.2, 0.3], [0.9, 0.4]], type: :f32) + check_vectorized_grad(x, &while_square_n/1) + end + + test "mixed: vectorized x * non-vectorized y (exercises broadcast_vectors)" do + y = Nx.tensor([10.0, 20.0, 30.0], type: :f32) + x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], type: :f32) + check_vectorized_grad(x, fn x -> Nx.sum(Nx.multiply(x, y)) end) + end + + test "two vectorized axes (exercises unbroadcast for multiple axes)" do + x = Nx.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], type: :f32) + x_vec = x |> Nx.vectorize(:a) |> Nx.vectorize(:b) + + check_vectorized_grad(x_vec, fn x -> Nx.sum(Nx.multiply(x, x)) end) + end + + test "composed: sigmoid(x @ w + b) (exercises composed chain with captures)" do + w = Nx.tensor([[0.5, -0.3], [0.2, 0.8]], type: :f32) + b = Nx.tensor([0.1, -0.1], type: :f32) + x = Nx.tensor([[1.0, 2.0], [3.0, 4.0], [-1.0, 0.5]], type: :f32) + + check_vectorized_grad(x, fn x -> + Nx.sum(Nx.sigmoid(Nx.add(Nx.dot(x, w), b))) + end) end end end
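A short usage sketch of what the vectorized grad support above enables, mirroring the new "multiply then sum" and "value_and_grad with vectorized target" tests (Nx.vectorize/2 and Nx.Defn.grad/2 are existing Nx APIs):

    x = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:batch)
    grad = Nx.Defn.grad(x, fn t -> Nx.sum(Nx.multiply(t, t)) end)
    # grad.vectorized_axes == [batch: 2]; each batch entry equals 2 * x,
    # matching the per-element grads checked by the check_vectorized_grad helper.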