Nx.LinAlg.triangular_solve gradient raises on batched (3D+) inputs — failure mode is a shape-mismatch in the Nx.dot call inside the custom grad formula. Related to but distinct from #1729: this one uses a plain batched tensor with no Nx.vectorize wrapping.
```elixir
a =
  Nx.tensor([
    [[4.0, 0.0], [2.0, 5.0]],
    [[9.0, 0.0], [3.0, 10.0]]
  ])

# a has shape {2, 2, 2}: a batch of two lower-triangular matrices
b = Nx.broadcast(1.0, {2, 2})

Nx.Defn.grad(a, fn x ->
  Nx.sum(Nx.LinAlg.triangular_solve(x, b))
end)

# ** (ArgumentError) dot/zip expects shapes to be compatible,
#    dimension 2 of left-side (2) does not equal dimension 0 of right-side (2)
```
Same class of bug as the LU/SVD batched-grad issues — the custom_grad formulas were written assuming 2D inputs and haven't been made batch-aware. Likely fixable with the same batch_axes/1 helper pattern PR #1731 applies to qr_grad and cholesky_grad.
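For reference, `Nx.dot/6` already accepts explicit contract and batch axes, which is the primitive a batch-aware grad formula would lean on. A minimal sketch of the idea (illustrative only — the tensor names here are hypothetical and this is not the actual grad code):

```elixir
# Two batches of 2x2 matrices
a = Nx.iota({2, 2, 2}, type: :f32)
b = Nx.iota({2, 2, 2}, type: :f32)

# A 2D formula like Nx.dot(a, b) breaks on 3D inputs; the batch-aware
# form contracts axis 2 of `a` against axis 1 of `b` while treating
# axis 0 of both as a batch axis:
result = Nx.dot(a, [2], [0], b, [1], [0])

Nx.shape(result)
# => {2, 2, 2}
```

The `batch_axes/1` helper in PR #1731 presumably computes those leading batch axes from the input rank so the same formula covers 2D and ND inputs.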
Found by a batched-grad fuzz probe while auditing linalg gradients.