
Support vectorized gradients via boundary devectorization (#1533) #1731

Open
blasphemetheus wants to merge 6 commits into elixir-nx:main from blasphemetheus:fix/1533-boundary-wrapper

Conversation

blasphemetheus (Contributor) commented Apr 14, 2026

Closes #1533 (pending review).

Replaces the existing per-op vectorization handling in Nx.Defn.Grad with a boundary wrapper.

Approach

At the entry of Grad.transform, align heterogeneous vectorized inputs to the union of vec axes via Nx.broadcast_vectors. This matches the implicit alignment the forward pass already performs and lets the grad recursion run in a homogeneous case. The seed and expression args are devectorized, so grad operates in fully devectorized space — vec axes become leading "batch" dimensions of the devectorized tensors.
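
As a rough illustration of the boundary step, using only public Nx calls (the actual wiring lives in private helpers inside Nx.Defn.Grad):

x = Nx.tensor([2.0, 5.0]) |> Nx.vectorize(:foo)
y = Nx.tensor([3.0, 4.0]) |> Nx.vectorize(:bar)

# 1. Align heterogeneous inputs to the union of vec axes.
[x, y] = Nx.broadcast_vectors([x, y])
# Both now carry vec[foo: 2, bar: 2].

# 2. Devectorize, so vec axes become leading batch dimensions.
devec_x = Nx.devectorize(x, keep_names: false)
# devec_x has shape {2, 2}; the grad recursion runs on tensors like this.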

A batch_count is threaded through update_grads and the per-op grad/5 clauses. unbroadcast and :reshape use it to preserve batch dims (which should not be summed out as if they were ordinary broadcast dims). Most ops just pass it through. In to_grad, each gradient is re-vectorized with the appropriate vec axes before returning to the caller, so the user's grad result carries the same vec shape as their input.
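
For intuition, here is a simplified, standalone analogue of the unbroadcast idea. The real clause in grad.ex also handles rank changes; this sketch assumes grad and the target share rank, and exists only to show why the first batch_count axes are exempt from summing:

defmodule UnbroadcastSketch do
  # Simplified analogue, not the actual Nx.Defn.Grad clause: sum out axes
  # that were broadcast in the forward pass, but never the first
  # `batch_count` axes, which are devectorized vec axes.
  def unbroadcast(grad, target_shape, batch_count) do
    grad_shape = Nx.shape(grad)

    sum_axes =
      for axis <- batch_count..(tuple_size(grad_shape) - 1)//1,
          elem(target_shape, axis) == 1 and elem(grad_shape, axis) > 1,
          do: axis

    case sum_axes do
      [] -> grad
      axes -> Nx.sum(grad, axes: axes, keep_axes: true)
    end
  end
end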

Behavior for heterogeneous vec axes

Each gradient in a heterogeneous-axes call carries all output vec axes (output-shape semantics). Each (foo:i, bar:j) element holds the per-instance partial derivative; nothing is collapsed. Callers that want a parameter-shaped gradient sum over the foreign axes themselves.

x = Nx.tensor([2.0, 5.0]) |> Nx.vectorize(:foo)
y = Nx.tensor([3.0, 4.0]) |> Nx.vectorize(:bar)

{grad_x, grad_y} =
  Nx.Defn.grad({x, y}, fn {x, y} -> Nx.sum(Nx.multiply(x, y)) end)

# grad_x and grad_y both have vec[foo: 2, bar: 2], with per-instance partials:
#   grad_x[foo:i, bar:j] = y[bar:j]
#   grad_y[foo:i, bar:j] = x[foo:i]
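
For example, a caller can reduce grad_x above to an x-shaped gradient by summing out the foreign :bar axis (a usage sketch; axis positions follow the example above):

# Devectorize so the foreign axis can be reduced, sum it out, then
# re-vectorize the remaining :foo axis.
param_grad_x =
  grad_x
  |> Nx.devectorize(keep_names: false)  # {2, 2}: [:foo, :bar] as plain dims
  |> Nx.sum(axes: [1])                  # sum out :bar -> {2}
  |> Nx.vectorize(:foo)                 # back to vec[foo: 2]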

Linalg simplifications

Nx.LinAlg.Cholesky and Nx.LinAlg.QR grad rules are updated to use rank-independent dot axes (Nx.dot(a, [-2], b, [-2])) so they survive the batched inputs that arise from devectorization. Cholesky also picks up an inlined Nx.eye(l) (was the deprecated l |> Nx.shape() |> Nx.eye()) and a small batch_transpose helper for swapping the last two axes regardless of rank.
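
The batch_transpose helper can be sketched as follows (an illustrative reconstruction from the description above, not the literal PR code):

defmodule BatchTransposeSketch do
  # Swap the last two axes of a tensor of any rank >= 2, leaving any
  # leading batch axes untouched.
  def batch_transpose(t) do
    rank = Nx.rank(t)
    axes = Enum.to_list(0..(rank - 3)//1) ++ [rank - 1, rank - 2]
    Nx.transpose(t, axes: axes)
  end
end

# For a 2D tensor this is a plain transpose; for {batch, m, n} it yields
# {batch, n, m}.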

Test coverage

The test suite in nx/test/nx/defn/grad_test.exs adds a check_vectorized_grad helper that compares vectorized grad against per-element stacked grads (a simplified sketch follows the list below), plus coverage for:

  • Edge cases: vectorize/devectorize inside grad, rename, reshape-then-vectorize, non-vectorized input → vectorized output, multiple vec axes, second-order grad, value_and_grad, grad of non-vec target with vec capture, grad w.r.t. tuple of mixed inputs, large vectorized batches, and constant-grad ops (all/any/argmax/argmin).
  • Heterogeneous-vec inputs: aligned at the boundary, with concrete-value tests for multiply, add, sin(add), and x^2 * y asserting per-instance partials directly.
  • Individual ops touched by the rework: exp, multiply, sum over a 2D inner axis, reshape, concatenate, window_sum, dot with a captured matrix, cumulative_sum, conv, while, mixed vec/non-vec, two stacked vec axes, composed sigmoid(x @ w + b).
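
A simplified single-axis sketch of the helper's comparison idea (the real helper also accepts pre-vectorized tensors and multi-axis vec inputs; assert_in_delta and @vec_atol come from the test module):

# Vectorize over :batch, take the grad once, then compare each batch
# slice against an independently computed per-element grad.
defp check_vectorized_grad_sketch(x_data, fun) do
  vec_grad =
    x_data
    |> Nx.vectorize(:batch)
    |> Nx.Defn.grad(fun)
    |> Nx.devectorize(keep_names: false)

  for i <- 0..(Nx.axis_size(x_data, 0) - 1) do
    elem_grad = Nx.Defn.grad(x_data[i], fun)

    for {v, e} <- Enum.zip(Nx.to_flat_list(vec_grad[i]), Nx.to_flat_list(elem_grad)) do
      assert_in_delta v, e, @vec_atol
    end
  end
end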

mix test test/nx/defn/grad_test.exs → 268 tests, 0 failures, 0 skipped. Full mix test from nx/ → 1353 doctests, 1262 tests, 0 failures, 0 skipped.

Out of scope — separate issues

Three independent vectorized-grad bugs surfaced while implementing the boundary wrapper but are not addressed here. Each carries its original reproducer:

blasphemetheus and others added 2 commits April 13, 2026 20:26
…1533)

Replaces the existing per-op vectorization handling in Nx.Defn.Grad
with a boundary wrapper:

* At the entry of `Grad.transform`, align heterogeneous vectorized
  inputs to the union of vec axes via `Nx.broadcast_vectors`. This
  matches the implicit alignment the forward pass already performs
  and lets the grad recursion run in a homogeneous case.

* Devectorize the gradient seed and the expression args in the
  recursion, so grad operates in fully devectorized space. Vec axes
  become leading "batch" dimensions of the devectorized tensors.

* Thread a `batch_count` through `update_grads` and the per-op
  `grad/5` clauses. `unbroadcast` and `:reshape` use it to preserve
  batch dims (which should not be summed out as if they were
  ordinary broadcast dims). Most ops just pass it through.

* In `to_grad`, re-vectorize each gradient with the appropriate vec
  axes before returning to the caller, so the user's grad result
  carries the same vec shape as their input.

Each gradient in a heterogeneous-axes call ends up carrying all output
vec axes (Option A — output-shape semantics, per polvalente's
clarification on the issue thread). Each (foo:i, bar:j) element holds
the per-instance partial derivative; nothing is collapsed. Callers
that want a parameter-shaped gradient sum over the foreign axes
themselves.

Linalg grad rules in Cholesky and QR are updated to use rank-
independent dot axes (`Nx.dot(a, [-2], b, [-2])`) so they survive
batched inputs that arise from devectorization.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Adds tests covering the boundary wrapper from the previous commit:

* `check_vectorized_grad` helper that compares vectorized grad
  against per-element stacked grads, used as the workhorse for
  most coverage.

* Edge case tests for vectorize/devectorize inside grad, rename,
  reshape-then-vectorize, non-vectorized input → vectorized output,
  multiple vectorized axes, second-order grad, value_and_grad,
  grad of non-vec target with vec capture, grad w.r.t. a tuple of
  mixed inputs, large vectorized batches, and constant-grad ops
  (all/any/argmax/argmin).

* Assertion that heterogeneous-vec inputs are aligned at the
  boundary and produce per-instance gradients with all output
  vec axes (replaces the prior `assert_raise`).

* Concrete-value tests for `multiply`, `add`, `sin(add)`, and
  `x^2 * y` over heterogeneous vec inputs, asserting the per-
  instance partials directly.

* Coverage tests for individual ops touched by the boundary
  rework: exp, multiply, sum-axis-on-2D-inner, reshape, concatenate,
  window_sum, dot with captured matrix, cumulative_sum, conv,
  while loop, mixed vec/non-vec, two vec axes, composed
  sigmoid(x @ w + b).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Comment thread nx/lib/nx/defn/grad.ex Outdated
Per @polvalente's suggestion, change `if/else` to `case`

Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Comment thread nx/test/nx/defn/grad_test.exs Outdated
@vec_atol 1.0e-4

# Compares vectorized grad against per-element stacked grads.
defp check_vectorized_grad(x_data, fun, opts \\ []) do
Contributor

Let's improve Nx.Testing.assert_all_close to support vectorization? We can do it in a follow-up PR

Contributor

assert_equal as well

devec = Nx.devectorize(t, keep_names: true)
new_axis = Nx.reshape(devec, {1, 3, 1}, names: [:x, nil, nil])
devec = Nx.devectorize(t, keep_names: false)
re_vec = Nx.vectorize(devec, :batch)
polvalente (Contributor) commented Apr 14, 2026

We need to test (for a 3D input, vectorized with [:batch]):

  1. output with different names and length: input |> Nx.devectorize(keep_names: false) |> Nx.vectorize([:a, :b])
  2. output with same name, but "hidden" vectors appear: input |> Nx.devectorize(keep_names: false) |> Nx.vectorize([:a, :b]) |> Nx.sum() |> Nx.devectorize(keep_names: false) |> Nx.vectorize([:batch])
  3. output with same length but different names

polvalente (Contributor) left a comment

PR generally looks great! I think the test cases I suggested are covered, but I'd like to confirm.
And we should also add a test on Nx.LinAlg.qr grad (hidden vectors, but it has custom grad) and Nx.LinAlg.eigh (hidden vectors, but relies on the autograd engine)

* `apply_boundary_broadcast/1` — switch from `length(flat) > 1` to a
  pattern match on `[_ | _]`, per @polvalente's review suggestion.
  Functionally equivalent (broadcast_vectors of a single tensor is
  already a no-op), idiomatically nicer.

* `Nx.LinAlg.QR.qr_grad` — make the `Nx.dot` calls batch-aware by
  passing explicit batch axes (everything except the last two dims).
  Without this, vectorized inputs that get devectorized into batched
  tensors hit a duplicate-axis-name error from `Nx.dot/4` because
  both operands carry the same `:batch` axis. The 2D-input case is
  unchanged (empty batch axes list is a no-op). Surfaced by the new
  vectorized QR grad test below.

* New tests covering the review-requested cases:
  - `hidden vec axes inside grad: rename to different names and lengths`
    — input vec[:batch], inside fn devectorize → vectorize([:a, :b])
    with different names and lengths
  - `hidden vec axes inside grad: same outer name with hidden
    intermediate axes` — input vec[:batch], devec → vec[:a, :b] →
    sum → devec → vec[:batch] → sum (the "hidden vectors" case
    where a different vec name appears in the middle of the
    computation)
  - `hidden vec axes inside grad: same length, different name`
  - `vectorized grad through Nx.LinAlg.qr (custom grad)` — exercises
    the QR `qr_grad` fix above
  - `vectorized grad through Nx.LinAlg.eigh (autograd via defn)` —
    exercises the boundary wrapper through eigh's autograd path

* Note on `Nx.Testing.assert_all_close` / `assert_equal` for
  vectorized tensors — being addressed in a separate follow-up PR
  per @polvalente's suggestion. The hidden-vec tests above use
  flat-list value comparisons + structural assertions to avoid the
  helper's current vec limitations.

Co-Authored-By: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Comment thread nx/test/nx/defn/grad_test.exs Outdated
Comment on lines +5225 to +5226
assert Nx.shape(Nx.devectorize(grad, keep_names: false)) == {4, 2, 3}
assert Nx.to_flat_list(grad) == List.duplicate(1.0, 24)
Contributor

Suggested change
assert Nx.shape(Nx.devectorize(grad, keep_names: false)) == {4, 2, 3}
assert Nx.to_flat_list(grad) == List.duplicate(1.0, 24)
assert Nx.devectorize(grad, keep_names: false) == Nx.broadcast(1.0, {4, 2, 3})

Comment thread nx/test/nx/defn/grad_test.exs Outdated
Comment on lines +5244 to +5245
assert Nx.shape(Nx.devectorize(grad, keep_names: false)) == {4, 2, 3}
assert Nx.to_flat_list(grad) == List.duplicate(1.0, 24)
Contributor

Suggested change
assert Nx.shape(Nx.devectorize(grad, keep_names: false)) == {4, 2, 3}
assert Nx.to_flat_list(grad) == List.duplicate(1.0, 24)
assert Nx.devectorize(grad, keep_names: false) == Nx.broadcast(1.0, {4, 2, 3})

Comment thread nx/test/nx/defn/grad_test.exs Outdated
Comment on lines +5214 to +5225
x = Nx.iota({4, 2, 3}, type: :f32) |> Nx.vectorize(:batch)

grad =
  Nx.Defn.grad(x, fn t ->
    t
    |> Nx.devectorize(keep_names: false)
    |> Nx.vectorize(a: 4, b: 2)
    |> Nx.sum()
  end)

assert grad.vectorized_axes == [batch: 4]
assert Nx.shape(Nx.devectorize(grad, keep_names: false)) == {4, 2, 3}
Contributor

I think this is right because this isn't too different from having names [:batch, nil, nil] -> names [:a, :b, nil] -> sum(axes: [2]).

Comment thread nx/test/nx/defn/grad_test.exs Outdated
Comment on lines +5439 to +5451
x_vec = x |> Nx.vectorize(:a) |> Nx.vectorize(:b)
vec_grad = Nx.Defn.grad(x_vec, fn x -> Nx.sum(Nx.multiply(x, x)) end)
vec_devec = Nx.devectorize(vec_grad, keep_names: false)

for i <- 0..1, j <- 0..1 do
  x_ij = x[i][j] |> Nx.reshape({2})
  elem_grad = Nx.Defn.grad(x_ij, fn x -> Nx.sum(Nx.multiply(x, x)) end)

  for {v, e} <-
        Enum.zip(Nx.to_flat_list(vec_devec[i][j]), Nx.to_flat_list(elem_grad)) do
    assert_in_delta v, e, @vec_atol
  end
end
Contributor

I think you should use the helper function you added above here

polvalente (Contributor) left a comment

@josevalim Overall this one looks good to me. WDYT?

Addresses the second round of @polvalente's inline comments on the
vectorized grad PR:

* `hidden vec axes inside grad` tests: replace the `to_flat_list +
  List.duplicate` workaround with `Nx.broadcast(1.0, devec_grad)`,
  which produces a broadcast tensor with the same shape AND names
  as the devectorized grad. This respects Pol's suggestion while
  handling the inner `:b` axis name that persists from the
  `Nx.vectorize(a: 4, b: 2)` step (`keep_names: false` strips only
  the merged vec-axis names, not pre-existing inner names).

* Extend `check_vectorized_grad` to accept a pre-vectorized tensor
  with any number of vec axes (was hardcoded to a single `:batch`
  axis). The helper now:
  - Accepts either a plain tensor (vectorized internally with
    `:batch` as before) or a pre-vectorized tensor.
  - Handles multi-axis vec inputs via a Cartesian product of vec
    dimension ranges.
  - Uses the devectorized view of the input for slicing, so the
    per-element grad is computed against the exact element that
    corresponds to the same slice of the vectorized grad (avoids
    the vec-axis-order-vs-original-axis-order confusion).

* Use the extended helper in the `"two vectorized axes"` test
  (replaces a manual nested iteration + `assert_in_delta` loop) and
  in the new QR / eigh grad tests (replaces the explicit `Nx.Defn.grad`
  call with a per-element comparison).

Co-Authored-By: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Resolves conflict in nx/lib/nx/defn/grad.ex: main renamed `:optional` to
`:block` and changed the args pattern (elixir-nx#1738 linalg block rename). This
branch removed the `parent_vectorized_names` threading from the same
function because the boundary wrapper handles vec axes centrally. Kept
the rename and new args pattern from main; kept the thread-removal from
this branch.

Verified: full grad_test.exs passes (273 tests, 0 failures).
blasphemetheus (Contributor, Author) commented Apr 21, 2026

@polvalente — while stress-testing a fuzz branch against a squash-merged simulation of this PR, I found that cholesky_grad is still broken for batched (3D+) inputs. The qr_grad rewrite on this branch uses batch_axes/1 + explicit batch axes on Nx.dot, and that works end-to-end; cholesky_grad uses Nx.dot([-2], l, [-2]) (axis-relative indices only, no batch axes), which is correct for 2D but produces a 4D intermediate when the input has a leading batch dim.

Minimal reproducer, run on this PR branch:

x = Nx.tensor([[[4.0, 2.0], [2.0, 5.0]], [[9.0, 3.0], [3.0, 10.0]]])
Nx.Defn.grad(x, fn a -> Nx.sum(Nx.LinAlg.cholesky(a)) end)
# ** (ArgumentError) incompatible dimensions for a and b on triangular solve
#    at nx/lib/nx/lin_alg/cholesky.ex:...

With a differently-shaped batched input the failure surfaces one line earlier as cannot broadcast tensor of dimensions {2, 4, 4, 2} to {2, 4, 4} — that's the 4D-from-dot problem I mentioned, caught at the subsequent Nx.divide(num, den).

Fix looks like mirroring the qr_grad treatment — add batch_axes/1 in cholesky.ex, change the dot to Nx.dot(a, [-2], ba, b, [-2], ba), and adapt the Nx.eye(l) |> Nx.add(1) denominator to match shape when batched.
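
Roughly, mirroring the existing QR treatment (a sketch of the proposed change, not a final patch; batch_axes/1 is the helper name already used for qr_grad on this branch):

# Every axis except the last two is treated as a batch axis.
defp batch_axes(t), do: Enum.to_list(0..(Nx.rank(t) - 3)//1)

# Inside cholesky_grad, the 2D-only contraction
#   Nx.dot(a, [-2], b, [-2])
# would become batch-aware:
#   ba = batch_axes(a)
#   Nx.dot(a, [-2], ba, b, [-2], ba)
# For 2D inputs, batch_axes/1 returns [] and behavior is unchanged.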

Happy to either:

  1. Extend this PR with the Cholesky fix (I can push a commit), or
  2. File a follow-up issue/PR after this one merges and fix in a separate change.

This finding is in the same class as #1740/#1741/#1742/#1743 (eigh / triangular_solve / LU / SVD batched grads). These are all follow-ups anyway.

blasphemetheus (Contributor, Author) commented:

On reflection, that belongs in its own PR; it's a fairly specific input shape that may not come up often.


Development

Successfully merging this pull request may close these issues.

Support vectorize/devectorize inside gradients