Nx.LinAlg.invert gradient raises on batched inputs. The custom_grad at nx/lib/nx/lin_alg.ex:866 uses Nx.dot without batch axes, producing a higher-rank intermediate that cannot be broadcast back to the input shape.
x = Nx.tensor([
  [[4.0, 1.0], [1.0, 5.0]],
  [[9.0, 1.0], [1.0, 10.0]]
]) # {2, 2, 2}: a batch of two 2x2 matrices

Nx.Defn.grad(x, fn a -> Nx.sum(Nx.LinAlg.invert(a)) end)
# ** (ArgumentError) cannot broadcast tensor of dimensions
# {2, 2, 2, 2, 2} to {2, 2, 2}
Same class as #1741/#1742/#1743 (triangular_solve / LU / SVD batched-grad bugs). The fix pattern mirrors QR/Cholesky in #1731: use batch_axes/1 + Nx.dot(a, [-1], ba, g, [-2], ba) style in the custom_grad formula, as sketched below.
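For concreteness, a minimal sketch of what the batched formula could look like, assuming a batch_axes/1-style helper as in the #1731 QR/Cholesky fix and the usual inverse gradient dA = -(A^-1)^T g (A^-1)^T. The helper, function name, and exact expression at lin_alg.ex:866 are assumptions for illustration, not verified against current main:

# Sketch only, written as a plain function for clarity; the real code
# sits inside the custom_grad callback in lin_alg.ex.
defp invert_grad(ans, g) do
  # All leading axes are batch axes; [] for a plain {n, n} input.
  ba = Enum.to_list(0..(Nx.rank(ans) - 3)//1)

  # Batched transpose of the inverse (adjoint == transpose for real inputs).
  ans_t = Nx.LinAlg.adjoint(ans)

  # dA = -(A^-1)^T g (A^-1)^T, contracting only the matrix axes and
  # carrying the batch axes through both dots via Nx.dot/6.
  ans_t
  |> Nx.negate()
  |> Nx.dot([-1], ba, g, [-2], ba)
  |> Nx.dot([-1], ba, ans_t, [-2], ba)
end

With the batch axes threaded through both dots, each intermediate keeps the input's rank ({2, 2, 2} in the repro above), so the final grad broadcasts cleanly.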
Found while extending the batched-grad fuzz probe across remaining hand-derived-grad linalg ops.