Consolidated tracking for a class of related gradient bugs in `Nx.LinAlg`: hand-written `custom_grad` formulas for linalg ops were written assuming 2-D input tensors and break when the input has a leading batch dimension. Forward passes work; only the backward pass miscomputes.

All individual cases are filed as their own focused issues so each has a clean target for a fix PR (see links below). I'll close them as duplicates once fixed, since they'll likely all be addressed in the same PR.
## Root cause (shared across all cases)
Several linalg ops have `custom_grad` formulas inside `nx/lib/nx/lin_alg.ex` and `nx/lib/nx/lin_alg/*.ex`. Those formulas typically contain `Nx.dot(a, b)` calls (no explicit axes / no batch axes), `Nx.shape(input)` pattern matches on 2-tuples, or `Nx.eye(l)` calls that assume 2-D. For 2-D inputs this works; for higher-rank inputs (any batched tensor), the formula produces wrong shapes or crashes.
## Fix pattern
For most of these ops, the fix mirrors what PR #1731 did for `qr_grad` and (partially) `cholesky_grad`: add a small `batch_axes/1` helper that returns the list of leading batch axes, and change every `Nx.dot(a, b)` to `Nx.dot(a, [-1], ba, b, [-2], ba)`, where `ba` is the helper's result, so the contraction axes are correct regardless of leading batch dims. `Nx.eye` (build with the right shape), `{m, n} = Nx.shape(input)` (read the last two dims instead of the full shape), and similar 2-D-only constructs get the same treatment. A sketch of the pattern follows.
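A minimal sketch of the pattern, assuming `batch_axes/1` behaves as #1731's description suggests (the helper body below is my guess, not the PR's code):

```elixir
# Sketch: batch_axes/1 lists a tensor's leading batch axes, i.e. every axis
# except the last two (matrix) dimensions. It returns [] for plain 2-D input.
batch_axes = fn t -> Enum.to_list(0..(Nx.rank(t) - 3)//1) end

a = Nx.iota({2, 3, 3}, type: :f32)
b = Nx.iota({2, 3, 3}, type: :f32)
ba = batch_axes.(a) # [0] here; [] for a 2-D tensor

# 2-D-only form: Nx.dot(a, b) contracts the wrong axes once a batch dim exists.
# Batch-aware form: contract the last axis of a against the second-to-last axis
# of b, carrying ba as batch axes on both sides:
Nx.dot(a, [-1], ba, b, [-2], ba) # => shape {2, 3, 3}: one matmul per batch entry
```

Inside a `defn` module the helper would presumably be a `deftransformp`, since it builds a plain Elixir list at trace time.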
## Cases

### [A] `Nx.LinAlg.qr` grad

Fixed in #1731 via the `batch_axes/1` helper (see Fix pattern above).

### [B] `Nx.LinAlg.cholesky` grad — partial fix in #1731

#1731 updated `cholesky_grad` to use an axis-relative `Nx.dot(…, [-2], l, [-2])`, but the `batch_axes/1` helper pattern used in `qr_grad` wasn't applied here. Result: it still produces a 4-D intermediate on batched input.

Repro:

```elixir
x = Nx.tensor([[[4.0, 2.0], [2.0, 5.0]], [[9.0, 3.0], [3.0, 10.0]]])

Nx.Defn.grad(x, fn a -> Nx.sum(Nx.LinAlg.cholesky(a)) end)
# ** (ArgumentError) incompatible dimensions for a and b on triangular solve
```
Fix direction: mirror `qr_grad`'s `batch_axes/1` approach in `cholesky_grad`.
### [C] `Nx.LinAlg.lu` grad — #1742

`lu_grad` uses unqualified `Nx.dot` calls that contract the wrong axes when the input has a leading batch dim.

Fix direction: `batch_axes/1` helper in `lu_grad`.

### [D] `Nx.LinAlg.svd` grad — #1743
`svd_grad` pattern-matches `{m, n} = Nx.shape(input)`, which fails immediately on 3-D input.
Repro:
```elixir
x = Nx.tensor([[[4.0, 3.0], [6.0, 3.0]], [[2.0, 1.0], [5.0, 7.0]]])

Nx.Defn.grad(x, fn t ->
  {u, s, vt} = Nx.LinAlg.svd(t)
  Nx.add(Nx.add(Nx.sum(u), Nx.sum(s)), Nx.sum(vt))
end)
# ** (MatchError) no match of right hand side value: {2, 2, 2}
```
Fix direction: more invasive than the others — the whole `svd_grad` body assumes rank 2. `Nx.eye(k)`/`Nx.eye(m)`/`Nx.eye(n)`, `Nx.new_axis(s_sq, 1)`, and `Nx.make_diagonal(s_input)` all need to become batch-aware.
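One likely ingredient (my illustration, not from the issue): `Nx.eye` accepts a full shape tuple, placing the identity in the last two axes, so the `Nx.eye(k)`-style calls could take the batched shape instead:

```elixir
t = Nx.iota({2, 3, 3}, type: :f32)

# 2-D-only: Nx.eye(3) builds a single 3x3 identity.
# Batch-aware: pass the full shape; leading dimensions become batch dimensions.
Nx.eye(Nx.shape(t)) # => {2, 3, 3} tensor holding two 3x3 identity matrices
```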
### [E] `Nx.LinAlg.eigh` grad — #1740

Distinct failure mode from the others — `eigh` has no `custom_grad`, so grad flows through the internal Jacobi iterations (a `while`). Two sub-bugs:
- 2-D input fails a reshape: `cannot reshape {} to {1, N, N}` (the grad path assumes 3-D).
- 3-D `f64` input fails: `expected 32 bits got: 64 bits` (hardcoded `f32` somewhere in the eigh path).
Repro (2-D case):
```elixir
x = Nx.tensor([[4.0, 2.0], [2.0, 5.0]], type: :f32)

Nx.Defn.grad(x, fn a ->
  {s, _} = Nx.LinAlg.eigh(a)
  Nx.sum(s)
end)
# ** (ArgumentError) cannot reshape, current shape {} is not compatible with new shape {1, 2, 2}
```
Only working case: 3-D batch-1 `:f32`, which is numerically correct (returns the identity matrix for the trace-gradient test).
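For reference, that working configuration looks like this (a sketch of the trace-gradient check described above):

```elixir
# Rank-3, batch of 1, :f32 is the one configuration that currently works.
x = Nx.tensor([[[4.0, 2.0], [2.0, 5.0]]], type: :f32)

Nx.Defn.grad(x, fn a ->
  {s, _vectors} = Nx.LinAlg.eigh(a)
  Nx.sum(s)
end)
# The sum of eigenvalues equals the trace, so the gradient should be the
# identity matrix (and is, in this one case).
```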
Fix direction: different from the others — no `custom_grad` to patch. Needs investigation inside the autograd path for the `eigh` `while` body.
### [F] `Nx.LinAlg.triangular_solve` grad — #1741

Same class as [C]; shape mismatch in the dot inside the `custom_grad`.
Repro:
```elixir
a = Nx.tensor([[[4.0, 0.0], [2.0, 5.0]], [[9.0, 0.0], [3.0, 10.0]]])
b = Nx.broadcast(1.0, {2, 2})

Nx.Defn.grad(a, fn x -> Nx.sum(Nx.LinAlg.triangular_solve(x, b)) end)
# ** (ArgumentError) dot/zip expects shapes to be compatible, dimension 2 of left-side (2)
# does not equal dimension 0 of right-side (2)
```
Note: related to, but distinct from, #1729 (vectorized grad); this is plain batched grad with no `Nx.vectorize` wrapping.

Fix direction: `batch_axes/1` helper.
### [G] `Nx.LinAlg.invert` grad — #1744

The `custom_grad` at `nx/lib/nx/lin_alg.ex:866` uses `Nx.dot` twice without batch axes.

Fix direction: `batch_axes/1` applied to both dots in the formula.

### [H] `Nx.LinAlg.solve` grad — #1745
Shape mismatch in the grad path (not a `custom_grad`; `solve` delegates internally).
Repro:
```elixir
a =
  Nx.tensor([
    [[4.0, 1.0, 0.0], [1.0, 5.0, 0.0], [0.0, 0.0, 6.0]],
    [[9.0, 1.0, 0.0], [1.0, 10.0, 0.0], [0.0, 0.0, 11.0]]
  ])

b = Nx.broadcast(1.0, {2, 3})

Nx.Defn.grad(a, fn x -> Nx.sum(Nx.LinAlg.solve(x, b)) end)
# ** (ArgumentError) dot/zip expects shapes to be compatible, dimension 2 of left-side (3)
# does not equal dimension 2 of right-side (2)
```
Fix direction: investigation — `solve` doesn't have an obvious `custom_grad` to patch; the failure is likely in how its internal definition lowers under grad.
### [I] `Nx.LinAlg.determinant` grad — #1746

Passes on 3×3 but fails on 4×4 and larger (batched or not).
Repro:
```elixir
x =
  Nx.tensor([
    [
      [5.0, 1.0, 2.0, 3.0],
      [4.0, 10.0, 6.0, 7.0],
      [8.0, 9.0, 15.0, 11.0],
      [12.0, 13.0, 14.0, 20.0]
    ]
  ])

Nx.Defn.grad(x, fn a -> Nx.sum(Nx.LinAlg.determinant(a)) end)
# ** (ArgumentError) dot/zip expects shapes to be compatible, dimension 3 of left-side (4)
# does not equal dimension 2 of right-side (1)
```
Fix direction: hardcoded axes in the grad's `Nx.dot` — same class as [C], [F], [G].
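Background for anyone picking this up (not from the issue): the determinant gradient is Jacobi's formula, `grad = det(A) * inverse(A)^T`, so a matrix product is expected in the formula; on 3×3 input (the size that still passes) the result can be checked directly:

```elixir
# Hypothetical sanity check on a 3x3 input, which passes today:
# by Jacobi's formula, the grad should equal det(a) * transpose(inverse(a)).
a = Nx.tensor([[4.0, 1.0, 0.0], [1.0, 5.0, 0.0], [0.0, 0.0, 6.0]])

grad = Nx.Defn.grad(a, fn x -> Nx.LinAlg.determinant(x) end)
expected = Nx.multiply(Nx.LinAlg.determinant(a), Nx.transpose(Nx.LinAlg.invert(a)))
# grad and expected should agree up to floating-point error
```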
## Out-of-scope but possibly related
- `pinv` — fails for batched input with a shape `cond` error. Likely downstream of [D]'s SVD grad bug, since `pinv` is defined in terms of SVD. Will probably fix itself once [D] is addressed.
- `norm` — fails the forward pass on batched input (`expected 1-D or 2-D tensor, got tensor with shape {2, 4, 4}`). Not a grad bug but a forward-pass rank limitation. Different class entirely.
## At-a-glance status

| Case | Op | Issue | Fix |
| --- | --- | --- | --- |
| A | `qr` | #1731 | `batch_axes/1` helper ✓ |
| B | `cholesky` | #1731 (partial) | `batch_axes/1` helper (not yet) |
| C | `lu` | #1742 | `batch_axes/1` helper |
| D | `svd` | #1743 | deeper rework of `svd_grad` |
| E | `eigh` | #1740 | autograd through the internal `while` |
| F | `triangular_solve` | #1741 | `batch_axes/1` helper |
| G | `invert` | #1744 | `batch_axes/1` helper |
| H | `solve` | #1745 | investigation |
| I | `determinant` | #1746 | `batch_axes/1` helper |

Five of these (B, C, F, G, I) look like copy-paste of the same small fix; D and H need deeper work; E is a distinct class (autograd through an internal `while`).
## How this was found
Discovered via a property-based "batched-grad probe" in a local fuzz branch that exercises `Nx.Defn.grad` through each linalg op with random batched inputs. Anyone wanting to reproduce can do so with:
```elixir
a = Nx.iota({2, 3, 3}, type: :f32) # or any batched matrix input

Nx.Defn.grad(a, fn x -> Nx.sum(op.(x)) end)
```

for each `op` above.