Nx.LinAlg gradients don't support batched (3D+) inputs #1748

@blasphemetheus

Description

This issue consolidates tracking for a class of related gradient bugs in Nx.LinAlg: the hand-written custom_grad formulas for linalg ops assume 2-D input tensors and break when the input has a leading batch dimension. Forward passes work; only the backward pass miscomputes.

Each individual case is filed as its own focused issue so it has a clean target for a fix PR (see links below). Where several end up addressed by the same PR, I'll close the rest as duplicates.

Root cause (shared across all cases)

Several linalg ops have custom_grad formulas inside nx/lib/nx/lin_alg.ex and nx/lib/nx/lin_alg/*.ex. Those formulas typically contain Nx.dot(a, b) calls (no explicit axes / no batch axes), Nx.shape(input) pattern-matches on 2-tuples, or Nx.eye(l) calls assuming 2-D. For 2-D inputs this works; for higher-rank inputs (any batched tensor), the formula produces wrong shapes or crashes.
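
A minimal illustration of that mechanism (standalone snippet, not library code):

a = Nx.iota({2, 3, 3}, type: :f32)

# Nx.dot/2 contracts the last axis of the left tensor with the second-to-last
# axis of the right tensor and declares no batch axes, so the leading batch
# dims multiply out into the result:
Nx.dot(a, a) |> Nx.shape()
# => {2, 3, 2, 3}   (not the per-matrix product shape {2, 3, 3})

# Declaring axis 0 as a batch axis restores the per-matrix product:
Nx.dot(a, [2], [0], a, [1], [0]) |> Nx.shape()
# => {2, 3, 3}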

Fix pattern

For most of these ops, the fix mirrors what PR #1731 did for qr_grad and (partially) cholesky_grad: add a small batch_axes/1 helper that returns the list of leading batch axes, and change every Nx.dot(a, b) to Nx.dot(a, [-1], ba, b, [-2], ba), where ba is the helper's result, so the contraction axes are correct regardless of leading batch dims. Similar treatment applies to Nx.eye (build with the full batched shape), {m, n} = Nx.shape(input) (read the last two dims instead of matching the full shape), etc. A sketch follows.
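
A minimal sketch of that pattern (the batch_axes/1 name comes from the PR; the body here is illustrative, not the exact code from #1731):

# Hypothetical batch_axes/1: every axis except the trailing two matrix dims.
# Returns [] for plain 2-D input, so 2-D behavior is unchanged.
defp batch_axes(tensor) do
  Enum.to_list(0..(Nx.rank(tensor) - 3)//1)
end

# Before (2-D only):      Nx.dot(a, b)
# After (batch-aware):    ba = batch_axes(a)
#                         Nx.dot(a, [-1], ba, b, [-2], ba)

# The same idea for the other 2-D assumptions:
#   Nx.eye(Nx.shape(t))            # batched identity instead of Nx.eye(n)
#   m = Nx.axis_size(input, -2)    # trailing dims instead of
#   n = Nx.axis_size(input, -1)    #   {m, n} = Nx.shape(input)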

Cases

[A] Nx.LinAlg.qr grad

  • Issue: tracked alongside Cholesky in PR #1731
  • Status: fix approved on PR branch; awaiting maintainer merge
  • Repro (pre-fix):
    x = Nx.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
    Nx.Defn.grad(x, fn a ->
      {q, r} = Nx.LinAlg.qr(a)
      Nx.add(Nx.sum(q), Nx.sum(r))
    end)
    # ** (ArgumentError) cannot broadcast tensor of dimensions {2, 4, 2, 4} to {2, 4, 4, 2}

[B] Nx.LinAlg.cholesky grad — partial fix in #1731

[C] Nx.LinAlg.lu grad — #1742

  • lu_grad uses Nx.dot calls without explicit batch axes, so they contract the wrong axes when the input has a leading batch dim.
  • Repro:
    x = Nx.tensor([[[4.0, 3.0], [6.0, 3.0]], [[2.0, 1.0], [5.0, 7.0]]])
    Nx.Defn.grad(x, fn t ->
      {_p, l, u} = Nx.LinAlg.lu(t)
      Nx.add(Nx.sum(l), Nx.sum(u))
    end)
    # ** (ArgumentError) cannot broadcast tensor of dimensions {2, 2, 2, 2, 2, 2, 2} to {2, 2, 2}
  • Fix direction: batch_axes/1 helper in lu_grad.

[D] Nx.LinAlg.svd grad — #1743

  • svd_grad pattern-matches {m, n} = Nx.shape(input), which fails immediately on 3-D input.
  • Repro:
    x = Nx.tensor([[[4.0, 3.0], [6.0, 3.0]], [[2.0, 1.0], [5.0, 7.0]]])
    Nx.Defn.grad(x, fn t ->
      {u, s, vt} = Nx.LinAlg.svd(t)
      Nx.add(Nx.add(Nx.sum(u), Nx.sum(s)), Nx.sum(vt))
    end)
    # ** (MatchError) no match of right hand side value: {2, 2, 2}
  • Fix direction: more invasive than the others; the whole svd_grad body assumes rank 2. Nx.eye(k)/Nx.eye(m)/Nx.eye(n), Nx.new_axis(s_sq, 1), and Nx.make_diagonal(s_input) all need to become batch-aware, as sketched below.
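  • A hypothetical sketch of batch-aware counterparts for those pieces (names illustrative; the real rewrite is larger):
    # Trailing matrix dims instead of a full-shape match:
    m = Nx.axis_size(input, -2)
    n = Nx.axis_size(input, -1)
    k = min(m, n)

    # Batched identity: Nx.eye accepts a full shape, so build batch ++ {m, m}:
    batch = input |> Nx.shape() |> Tuple.to_list() |> Enum.drop(-2)
    eye_m = Nx.eye(List.to_tuple(batch ++ [m, m]))

    # Insert the broadcast axis relative to the end, not at hardcoded position 1:
    s_sq_col = Nx.new_axis(s_sq, -1)

    # Batched diagonal of s without Nx.make_diagonal (which expects 1-D input):
    # multiplying a batched identity by s broadcast as a row vector puts s on
    # each diagonal.
    eye_k = Nx.eye(List.to_tuple(batch ++ [k, k]))
    diag_s = Nx.multiply(eye_k, Nx.new_axis(s, -2))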

[E] Nx.LinAlg.eigh grad — #1740

  • Distinct failure mode from the others: eigh has no custom_grad, so the gradient flows through the internal Jacobi iterations (an Nx while loop). Two sub-bugs:
    • 2-D input fails a reshape: cannot reshape {} to {1, N, N} (the grad path assumes 3-D).
    • 3-D f64 input fails: expected 32 bits got: 64 bits (hardcoded f32 somewhere in the eigh path).
  • Repro (2-D case):
    x = Nx.tensor([[4.0, 2.0], [2.0, 5.0]], type: :f32)
    Nx.Defn.grad(x, fn a -> {s, _} = Nx.LinAlg.eigh(a); Nx.sum(s) end)
    # ** (ArgumentError) cannot reshape, current shape {} is not compatible with new shape {1, 2, 2}
  • Only working case: 3-D input with batch size 1 and type :f32, which is numerically correct (returns the identity matrix for the trace-gradient test).
  • Fix direction: different from the others — no custom_grad to patch. Needs investigation inside the autograd path for the eigh while body.

[F] Nx.LinAlg.triangular_solve grad — #1741

  • Same class as [C]; shape mismatch in the dot inside the custom_grad.
  • Repro:
    a = Nx.tensor([[[4.0, 0.0], [2.0, 5.0]], [[9.0, 0.0], [3.0, 10.0]]])
    b = Nx.broadcast(1.0, {2, 2})
    Nx.Defn.grad(a, fn x -> Nx.sum(Nx.LinAlg.triangular_solve(x, b)) end)
    # ** (ArgumentError) dot/zip expects shapes to be compatible, dimension 2 of left-side (2)
    #    does not equal dimension 0 of right-side (2)
  • Note: related to but distinct from #1729 (vectorized grad); this is plain batched grad with no Nx.vectorize wrapping.
  • Fix direction: batch_axes/1 helper.

[G] Nx.LinAlg.invert grad — #1744

  • custom_grad at nx/lib/nx/lin_alg.ex:866 uses Nx.dot twice without batch axes.
  • Repro:
    x = Nx.tensor([[[4.0, 1.0], [1.0, 5.0]], [[9.0, 1.0], [1.0, 10.0]]])
    Nx.Defn.grad(x, fn a -> Nx.sum(Nx.LinAlg.invert(a)) end)
    # ** (ArgumentError) cannot broadcast tensor of dimensions {2, 3, 2, 2, 3} to {2, 3, 3}
  • Fix direction: batch_axes/1 applied to both dots in the formula, as sketched below.
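  • For reference, the rule here is d(A⁻¹) = -A⁻¹ dA A⁻¹, so with C = A⁻¹ and upstream grad g the adjoint is -Cᵀ g Cᵀ. A hypothetical batch-aware sketch (not the actual code at lin_alg.ex:866):
    ba = batch_axes(c)            # hypothetical helper from the fix pattern above
    ct = Nx.LinAlg.adjoint(c)     # transposes the trailing two axes
    da =
      ct
      |> Nx.dot([-1], ba, g, [-2], ba)
      |> Nx.dot([-1], ba, ct, [-2], ba)
      |> Nx.negate()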

[H] Nx.LinAlg.solve grad — #1745

  • Shape mismatch in the grad path (not a custom_grad; solve delegates internally).
  • Repro:
    a = Nx.tensor([[[4.0, 1.0, 0.0], [1.0, 5.0, 0.0], [0.0, 0.0, 6.0]],
                   [[9.0, 1.0, 0.0], [1.0, 10.0, 0.0], [0.0, 0.0, 11.0]]])
    b = Nx.broadcast(1.0, {2, 3})
    Nx.Defn.grad(a, fn x -> Nx.sum(Nx.LinAlg.solve(x, b)) end)
    # ** (ArgumentError) dot/zip expects shapes to be compatible, dimension 2 of left-side (3)
    #    does not equal dimension 2 of right-side (2)
  • Fix direction: investigation — solve doesn't have an obvious custom_grad to patch; the failure is likely in how its internal definition lowers under grad.

[I] Nx.LinAlg.determinant grad — #1746

  • Passes on 3×3 but fails on 4×4 and larger (batched or not).
  • Repro:
    x = Nx.tensor([[
          [5.0, 1.0, 2.0, 3.0],
          [4.0, 10.0, 6.0, 7.0],
          [8.0, 9.0, 15.0, 11.0],
          [12.0, 13.0, 14.0, 20.0]
        ]])
    Nx.Defn.grad(x, fn a -> Nx.sum(Nx.LinAlg.determinant(a)) end)
    # ** (ArgumentError) dot/zip expects shapes to be compatible, dimension 3 of left-side (4)
    #    does not equal dimension 2 of right-side (1)
  • Fix direction: hardcoded axes in the grad's Nx.dot, same class as [C], [F], [G]; the underlying rule is sketched below.
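  • The rule itself is batch-friendly: ∂det(A)/∂A = det(A) · A⁻ᵀ. A hypothetical batch-aware sketch (the actual grad may go through LU rather than an explicit inverse):
    # g and det carry the batch shape; add trailing axes so the scale
    # broadcasts over each matrix (assumes nonsingular input):
    a_inv_t = a |> Nx.LinAlg.invert() |> Nx.LinAlg.adjoint()
    scale = g |> Nx.multiply(det) |> Nx.new_axis(-1) |> Nx.new_axis(-1)
    da = Nx.multiply(scale, a_inv_t)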

Out-of-scope but possibly related

  • pinv — fails for batched input with a shape-cond error. Likely downstream of [D]'s SVD grad bug since pinv is defined in terms of SVD. Will probably fix itself once [D] is addressed.
  • norm — fails forward pass on batched input (expected 1-D or 2-D tensor, got tensor with shape {2, 4, 4}). Not a grad bug but a forward-pass rank limitation. Different class entirely.

At-a-glance status

ID  Op                Issue         Root cause                 Fix approach
A   qr                #1731 (PR)    hardcoded dot axes         batch_axes/1 helper ✓
B   cholesky          #1731 thread  hardcoded dot axes         batch_axes/1 helper (not yet)
C   lu                #1742         hardcoded dot axes         batch_axes/1 helper
D   svd               #1743         2-D assumption throughout  invasive rewrite
E   eigh              #1740         autograd-through-while     different class
F   triangular_solve  #1741         hardcoded dot axes         batch_axes/1 helper
G   invert            #1744         hardcoded dot axes         batch_axes/1 helper
H   solve             #1745         grad-path shape bug        investigation
I   determinant       #1746         hardcoded dot axes         batch_axes/1 helper

Five of these (B, C, F, G, I) look like copy-paste of the same small fix; D and H need deeper work; E is a distinct class (autograd through an internal while).

How this was found

Discovered via a property-based "batched-grad probe" in a local fuzz branch that exercises Nx.Defn.grad through each linalg op with random batched inputs. Anyone wanting to reproduce can do so with:

a = Nx.iota({2, 3, 3}, type: :f32)   # or any batched matrix input
Nx.Defn.grad(a, fn x -> Nx.sum(op.(x)) end)

for each op above (for ops that return tuples, sum over all outputs, as in the repros).
