Support vectorized gradients via boundary devectorization (#1533) #1731

blasphemetheus wants to merge 6 commits into elixir-nx:main from
Conversation
Support vectorized gradients via boundary devectorization (#1533)

Replaces the existing per-op vectorization handling in Nx.Defn.Grad with a boundary wrapper:

* At the entry of `Grad.transform`, align heterogeneous vectorized inputs to the union of vec axes via `Nx.broadcast_vectors`. This matches the implicit alignment the forward pass already performs and lets the grad recursion run on a homogeneous case.
* Devectorize the gradient seed and the expression args in the recursion, so grad operates in fully devectorized space. Vec axes become leading "batch" dimensions of the devectorized tensors.
* Thread a `batch_count` through `update_grads` and the per-op `grad/5` clauses. `unbroadcast` and `:reshape` use it to preserve batch dims (which should not be summed out as if they were ordinary broadcast dims). Most ops just pass it through.
* In `to_grad`, re-vectorize each gradient with the appropriate vec axes before returning to the caller, so the user's grad result carries the same vec shape as their input.

Each gradient in a heterogeneous-axes call ends up carrying all output vec axes (Option A, output-shape semantics, per polvalente's clarification on the issue thread). Each `(foo: i, bar: j)` element holds the per-instance partial derivative; nothing is collapsed. Callers that want a parameter-shaped gradient sum over the foreign axes themselves.

Linalg grad rules in Cholesky and QR are updated to use rank-independent dot axes (`Nx.dot(a, [-2], b, [-2])`) so they survive batched inputs that arise from devectorization.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
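For reviewers, a minimal sketch of what the boundary does, using only the public Nx API (the names and shapes are illustrative, not the actual `Grad` internals):

```elixir
x = Nx.iota({3, 2}, type: :f32) |> Nx.vectorize(:foo)
y = Nx.iota({4, 2}, type: :f32) |> Nx.vectorize(:bar)

# 1. Align heterogeneous vec axes to their union, as the forward pass does.
[x_aligned, _y_aligned] = Nx.broadcast_vectors([x, y])
x_aligned.vectorized_axes
#=> [foo: 3, bar: 4]

# 2. Devectorize: vec axes become leading batch dims of a plain tensor.
x_devec = Nx.devectorize(x_aligned, keep_names: false)
Nx.shape(x_devec)
#=> {3, 4, 2}

# 3. After the devectorized grad recursion, re-vectorize the result.
Nx.vectorize(x_devec, foo: 3, bar: 4).vectorized_axes
#=> [foo: 3, bar: 4]
```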
Adds tests covering the boundary wrapper from the previous commit:

* `check_vectorized_grad` helper that compares vectorized grad against per-element stacked grads, used as the workhorse for most coverage.
* Edge-case tests for vectorize/devectorize inside grad, rename, reshape-then-vectorize, non-vectorized input → vectorized output, multiple vectorized axes, second-order grad, value_and_grad, grad of a non-vec target with vec capture, grad w.r.t. a tuple of mixed inputs, large vectorized batches, and constant-grad ops (all/any/argmax/argmin).
* Assertion that heterogeneous-vec inputs are aligned at the boundary and produce per-instance gradients with all output vec axes (replaces the prior `assert_raise`).
* Concrete-value tests for `multiply`, `add`, `sin(add)`, and `x^2 * y` over heterogeneous vec inputs, asserting the per-instance partials directly.
* Coverage tests for individual ops touched by the boundary rework: exp, multiply, sum-axis-on-2D-inner, reshape, concatenate, window_sum, dot with captured matrix, cumulative_sum, conv, while loop, mixed vec/non-vec, two vec axes, composed sigmoid(x @ w + b).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
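For orientation, the core idea of the helper is roughly the following (a sketch; the committed version differs in details beyond the signature):

```elixir
defp check_vectorized_grad(x_data, fun, opts \\ []) do
  atol = opts[:atol] || @vec_atol
  x = Nx.tensor(x_data, type: :f32)
  x_vec = Nx.vectorize(x, :batch)

  # Grad of the whole batch at once, viewed as a plain tensor.
  vec_devec = Nx.devectorize(Nx.Defn.grad(x_vec, fun), keep_names: false)

  # Grad of each element alone must match its slice of the batched grad.
  for i <- 0..(Nx.axis_size(x, 0) - 1) do
    elem_grad = Nx.Defn.grad(x[i], fun)

    for {v, e} <- Enum.zip(Nx.to_flat_list(vec_devec[i]), Nx.to_flat_list(elem_grad)) do
      assert_in_delta v, e, atol
    end
  end
end
```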
Per @polvalente's suggestion, change `if`/`else` to `case`.

Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
```elixir
@vec_atol 1.0e-4

# Compares vectorized grad against per-element stacked grads.
defp check_vectorized_grad(x_data, fun, opts \\ []) do
```
Let's improve `Nx.Testing.assert_all_close` to support vectorization? We can do it in a follow-up PR.
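In the meantime, a possible shim for vectorized comparisons (a hypothetical helper; assumes `Nx.Testing.assert_all_close/3` accepts the usual tolerance opts):

```elixir
defp assert_vec_all_close(left, right, opts \\ []) do
  # Compare vec metadata structurally, then values on the devectorized views.
  assert left.vectorized_axes == right.vectorized_axes

  Nx.Testing.assert_all_close(
    Nx.devectorize(left, keep_names: false),
    Nx.devectorize(right, keep_names: false),
    opts
  )
end
```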
```elixir
devec = Nx.devectorize(t, keep_names: true)
new_axis = Nx.reshape(devec, {1, 3, 1}, names: [:x, nil, nil])
devec = Nx.devectorize(t, keep_names: false)
re_vec = Nx.vectorize(devec, :batch)
```
We need to test (for a 3D input, vectorized with `[:batch]`):

- output with different names and length: `input |> Nx.devectorize(keep_names: false) |> Nx.vectorize([:a, :b])`
- output with same name, but "hidden" vectors appear: `input |> Nx.devectorize(keep_names: false) |> Nx.vectorize([:a, :b]) |> Nx.sum() |> Nx.devectorize(keep_names: false) |> Nx.vectorize([:batch])`
- output with same length but different names
polvalente left a comment:
PR generally looks great! I think the test cases I suggested are covered, but I'd like to confirm. We should also add a test on `Nx.LinAlg.qr` grad (hidden vectors, but it has a custom grad) and on `Nx.LinAlg.eigh` (hidden vectors, but relies on the autograd engine).
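Something like the following might cover the QR case (a sketch assuming the `check_vectorized_grad` helper from this PR; the diagonally-dominant input is arbitrary):

```elixir
test "vectorized grad through Nx.LinAlg.qr (custom grad)" do
  # Batch of well-conditioned matrices so the QR grad is stable.
  x =
    Nx.iota({3, 4, 4}, type: :f32)
    |> Nx.multiply(0.1)
    |> Nx.add(Nx.eye(4))

  check_vectorized_grad(x, fn t ->
    {_q, r} = Nx.LinAlg.qr(t)
    Nx.sum(r)
  end)
end
```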
* `apply_boundary_broadcast/1`: switch from `length(flat) > 1` to a pattern match on `[_ | _]`, per @polvalente's review suggestion. Functionally equivalent (broadcast_vectors of a single tensor is already a no-op) and idiomatically nicer.
* `Nx.LinAlg.QR.qr_grad`: make the `Nx.dot` calls batch-aware by passing explicit batch axes (everything except the last two dims). Without this, vectorized inputs that get devectorized into batched tensors hit a duplicate-axis-name error from `Nx.dot/4` because both operands carry the same `:batch` axis. The 2D-input case is unchanged (an empty batch-axes list is a no-op). Surfaced by the new vectorized QR grad test below.
* New tests covering the review-requested cases:
  - `hidden vec axes inside grad: rename to different names and lengths`: input vec[:batch]; inside the fn, devectorize → vectorize([:a, :b]) with different names and lengths
  - `hidden vec axes inside grad: same outer name with hidden intermediate axes`: input vec[:batch], devec → vec[:a, :b] → sum → devec → vec[:batch] → sum (the "hidden vectors" case where a different vec name appears in the middle of the computation)
  - `hidden vec axes inside grad: same length, different name`
  - `vectorized grad through Nx.LinAlg.qr (custom grad)`: exercises the QR `qr_grad` fix above
  - `vectorized grad through Nx.LinAlg.eigh (autograd via defn)`: exercises the boundary wrapper through eigh's autograd path
* Note on `Nx.Testing.assert_all_close` / `assert_equal` for vectorized tensors: being addressed in a separate follow-up PR per @polvalente's suggestion. The hidden-vec tests above use flat-list value comparisons plus structural assertions to avoid the helper's current vec limitations.

Co-Authored-By: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
```elixir
assert Nx.shape(Nx.devectorize(grad, keep_names: false)) == {4, 2, 3}
assert Nx.to_flat_list(grad) == List.duplicate(1.0, 24)
```

Suggested change:

```suggestion
assert Nx.devectorize(grad, keep_names: false) == Nx.broadcast(1.0, {4, 2, 3})
```
```elixir
assert Nx.shape(Nx.devectorize(grad, keep_names: false)) == {4, 2, 3}
assert Nx.to_flat_list(grad) == List.duplicate(1.0, 24)
```

Suggested change:

```suggestion
assert Nx.devectorize(grad, keep_names: false) == Nx.broadcast(1.0, {4, 2, 3})
```
```elixir
x = Nx.iota({4, 2, 3}, type: :f32) |> Nx.vectorize(:batch)

grad =
  Nx.Defn.grad(x, fn t ->
    t
    |> Nx.devectorize(keep_names: false)
    |> Nx.vectorize(a: 4, b: 2)
    |> Nx.sum()
  end)

assert grad.vectorized_axes == [batch: 4]
assert Nx.shape(Nx.devectorize(grad, keep_names: false)) == {4, 2, 3}
```
I think this is right because this isn't too different from having names `[:batch, nil, nil]` -> names `[:a, :b, nil]` -> `sum(axes: [2])`.
```elixir
x_vec = x |> Nx.vectorize(:a) |> Nx.vectorize(:b)
vec_grad = Nx.Defn.grad(x_vec, fn x -> Nx.sum(Nx.multiply(x, x)) end)
vec_devec = Nx.devectorize(vec_grad, keep_names: false)

for i <- 0..1, j <- 0..1 do
  x_ij = x[i][j] |> Nx.reshape({2})
  elem_grad = Nx.Defn.grad(x_ij, fn x -> Nx.sum(Nx.multiply(x, x)) end)

  for {v, e} <-
        Enum.zip(Nx.to_flat_list(vec_devec[i][j]), Nx.to_flat_list(elem_grad)) do
    assert_in_delta v, e, @vec_atol
  end
end
```
I think you should use the helper function you added above here.
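With the helper extended to accept pre-vectorized tensors (see the follow-up commit below), the whole loop could collapse to something like:

```elixir
x_vec = x |> Nx.vectorize(:a) |> Nx.vectorize(:b)
check_vectorized_grad(x_vec, fn t -> Nx.sum(Nx.multiply(t, t)) end)
```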
polvalente left a comment:

@josevalim Overall this one looks good to me. WDYT?
Addresses the second round of @polvalente's inline comments on the vectorized grad PR:

* `hidden vec axes inside grad` tests: replace the `to_flat_list + List.duplicate` workaround with `Nx.broadcast(1.0, devec_grad)`, which produces a broadcast tensor with the same shape AND names as the devectorized grad. This respects Pol's suggestion while handling the inner `:b` axis name that persists from the `Nx.vectorize(a: 4, b: 2)` step (`keep_names: false` strips only the merged vec-axis names, not pre-existing inner names).
* Extend `check_vectorized_grad` to accept a pre-vectorized tensor with any number of vec axes (it was hardcoded to a single `:batch` axis). The helper now:
  - Accepts either a plain tensor (vectorized internally with `:batch` as before) or a pre-vectorized tensor.
  - Handles multi-axis vec inputs via a Cartesian product of vec dimension ranges.
  - Uses the devectorized view of the input for slicing, so the per-element grad is computed against the exact element that corresponds to the same slice of the vectorized grad (avoiding the vec-axis-order vs. original-axis-order confusion).
* Use the extended helper in the `"two vectorized axes"` test (replacing a manual nested iteration + `assert_in_delta` loop) and in the new QR / eigh grad tests (replacing the explicit `Nx.Defn.grad` call with a per-element comparison).

Co-Authored-By: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
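A sketch of the multi-axis core described above (helper names are illustrative; the committed code differs in detail):

```elixir
# Enumerate every multi-index over the vec axes, e.g. [a: 2, b: 3] ->
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]].
defp vec_indices(vectorized_axes) do
  Enum.reduce(vectorized_axes, [[]], fn {_name, size}, acc ->
    for idx <- acc, i <- 0..(size - 1), do: idx ++ [i]
  end)
end

# Slice the devectorized input and the devectorized grad with the same
# multi-index, so the per-element grad lines up with its exact slice.
defp check_slice(devec_x, devec_grad, idx, fun, atol) do
  x_slice = Enum.reduce(idx, devec_x, fn i, acc -> acc[i] end)
  g_slice = Enum.reduce(idx, devec_grad, fn i, acc -> acc[i] end)
  elem_grad = Nx.Defn.grad(x_slice, fun)

  for {v, e} <- Enum.zip(Nx.to_flat_list(g_slice), Nx.to_flat_list(elem_grad)) do
    assert_in_delta v, e, atol
  end
end
```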
Resolves a conflict in `nx/lib/nx/defn/grad.ex`: main renamed `:optional` to `:block` and changed the args pattern (elixir-nx#1738, the linalg block rename), while this branch removed the `parent_vectorized_names` threading from the same function because the boundary wrapper handles vec axes centrally. Kept the rename and the new args pattern from main; kept the threading removal from this branch.

Verified: the full grad_test.exs passes (273 tests, 0 failures).
@polvalente: while stress-testing a fuzz branch against a squash-merged simulation of this PR, I found that the Cholesky grad still fails for batched inputs. Minimal reproducer, run on this PR branch:

```elixir
x = Nx.tensor([[[4.0, 2.0], [2.0, 5.0]], [[9.0, 3.0], [3.0, 10.0]]])
Nx.Defn.grad(x, fn a -> Nx.sum(Nx.LinAlg.cholesky(a)) end)
# ** (ArgumentError) incompatible dimensions for a and b on triangular solve
#    at nx/lib/nx/lin_alg/cholesky.ex:...
```

With a differently-shaped batched input the failure surfaces one line earlier. The fix looks like mirroring the batch-aware `Nx.dot` change from `qr_grad` in the Cholesky grad rule. Happy to either fold the fix into this PR or open a follow-up.

This finding is in the same class as #1740/#1741/#1742/#1743 (eigh / triangular_solve / LU / SVD batched grads). These are all follow-ups anyway.
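For reference, the `qr_grad` pattern being mirrored is roughly this (a sketch; `batched_dot` is a hypothetical name, not the code in the diff):

```elixir
# Contract the last axis of `a` against the second-to-last axis of `b`,
# batching over all leading axes; for rank-2 inputs the batch list is
# empty and this degenerates to a plain dot.
defp batched_dot(a, b) do
  batch_axes = Enum.to_list(0..(Nx.rank(a) - 3)//1)
  Nx.dot(a, [-1], batch_axes, b, [-2], batch_axes)
end
```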
On reflection, that stuff belongs in its own PR; it's a fairly specific shape of input that maybe isn't often given.
Closes #1533 (pending review).
Replaces the existing per-op vectorization handling in `Nx.Defn.Grad` with a boundary wrapper.

## Approach
At the entry of `Grad.transform`, align heterogeneous vectorized inputs to the union of vec axes via `Nx.broadcast_vectors`. This matches the implicit alignment the forward pass already performs and lets the grad recursion run on a homogeneous case. The seed and expression args are devectorized, so grad operates in fully devectorized space: vec axes become leading "batch" dimensions of the devectorized tensors.

A `batch_count` is threaded through `update_grads` and the per-op `grad/5` clauses. `unbroadcast` and `:reshape` use it to preserve batch dims (which should not be summed out as if they were ordinary broadcast dims). Most ops just pass it through. In `to_grad`, each gradient is re-vectorized with the appropriate vec axes before returning to the caller, so the user's grad result carries the same vec shape as their input.
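Concretely, from the caller's side (a minimal illustration; the shapes and function are arbitrary):

```elixir
x = Nx.iota({3, 2}, type: :f32) |> Nx.vectorize(:batch)
g = Nx.Defn.grad(x, fn t -> Nx.sum(Nx.multiply(t, t)) end)

g.vectorized_axes
#=> [batch: 3]
Nx.shape(g)
#=> {2}
```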
## Behavior for heterogeneous vec axes

Each gradient in a heterogeneous-axes call carries all output vec axes (output-shape semantics). Each `(foo: i, bar: j)` element holds the per-instance partial derivative; nothing is collapsed. Callers that want a parameter-shaped gradient sum over the foreign axes themselves.
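A minimal illustration of these semantics (axis names arbitrary; the final line is the parameter-shaped reduction a caller would do by hand):

```elixir
x = Nx.iota({2}, type: :f32) |> Nx.vectorize(:foo)
y = Nx.iota({3}, type: :f32) |> Nx.vectorize(:bar)

# d/dx (x * y) = y, held per (foo: i, bar: j) instance.
{gx, _gy} = Nx.Defn.grad({x, y}, fn {a, b} -> Nx.multiply(a, b) end)
gx.vectorized_axes
#=> [foo: 2, bar: 3]

# Parameter-shaped gradient for x: sum out the foreign :bar axis.
gx |> Nx.devectorize(keep_names: false) |> Nx.sum(axes: [1]) |> Nx.vectorize(:foo)
```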
## Linalg simplifications

`Nx.LinAlg.Cholesky` and `Nx.LinAlg.QR` grad rules are updated to use rank-independent dot axes (`Nx.dot(a, [-2], b, [-2])`) so they survive the batched inputs that arise from devectorization. Cholesky also picks up an inlined `Nx.eye(l)` (was the deprecated `l |> Nx.shape() |> Nx.eye()`) and a small `batch_transpose` helper for swapping the last two axes regardless of rank.
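The transpose helper is along these lines (a sketch; the committed version may differ):

```elixir
# Swap the last two axes regardless of rank; leading batch axes are
# left in place, so rank-2 inputs get a plain transpose.
defp batch_transpose(t) do
  rank = Nx.rank(t)
  axes = Enum.to_list(0..(rank - 3)//1) ++ [rank - 1, rank - 2]
  Nx.transpose(t, axes: axes)
end
```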
## Test coverage

The test suite in `nx/test/nx/defn/grad_test.exs` adds a `check_vectorized_grad` helper that compares vectorized grad against per-element stacked grads, plus coverage for:

* Edge cases: `value_and_grad`, grad of a non-vec target with vec capture, grad w.r.t. a tuple of mixed inputs, large vectorized batches, and constant-grad ops (all/any/argmax/argmin).
* Concrete-value tests for `multiply`, `add`, `sin(add)`, and `x^2 * y`, asserting per-instance partials directly.
* Individual ops: `exp`, `multiply`, `sum` over a 2D inner axis, `reshape`, `concatenate`, `window_sum`, `dot` with a captured matrix, `cumulative_sum`, `conv`, `while`, mixed vec/non-vec, two stacked vec axes, composed `sigmoid(x @ w + b)`.

`mix test test/nx/defn/grad_test.exs` → 268 tests, 0 failures, 0 skipped. Full `mix test` from `nx/` → 1353 doctests, 1262 tests, 0 failures, 0 skipped.
## Out of scope (separate issues)

Three independent vectorized-grad bugs surfaced while implementing the boundary wrapper but are not addressed here. Each carries its original reproducer:

* `unbroadcast` inside the dot gradient does not account for vec axes; the contracted form fails for any vectorized input.
* `apply_vectorized` ops (`Nx.LinAlg.cholesky`, `Nx.LinAlg.triangular_solve`, `Nx.cond`/`Nx.select`) re-vectorize internally, which the boundary-wrapper grad cannot reconcile with its devectorized recursion.