`Nx.LinAlg.lu` gradient raises on batched inputs. `lu_grad` uses plain `Nx.dot/2` calls that contract the wrong axes when the input has a leading batch dimension.
```elixir
x =
  Nx.tensor([
    [[4.0, 3.0], [6.0, 3.0]],
    [[2.0, 1.0], [5.0, 7.0]]
  ])

# x has shape {2, 2, 2}

Nx.Defn.grad(x, fn t ->
  {_p, l, u} = Nx.LinAlg.lu(t)
  Nx.add(Nx.sum(l), Nx.sum(u))
end)
# ** (ArgumentError) cannot broadcast tensor of dimensions
#    {2, 2, 2, 2, 2, 2, 2} to {2, 2, 2}
```
Same class of bug as #1741 (`triangular_solve`), and as the QR/Cholesky batched-gradient bugs fixed on the branch behind #1731. Likely fixable with the same `batch_axes/1` helper pattern: `Nx.dot(a, [-1], ba, b, [-2], ba)` instead of `Nx.dot(a, b)`.
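For illustration, a hedged sketch of what the fix inside `lu_grad` might look like. The `batch_axes/1` helper here is hypothetical (modeled on the pattern described above, not copied from the Nx source): it returns the indices of all leading axes before the trailing two matrix axes, so the six-arity `Nx.dot/6` contracts only the matrix dimensions and maps over the batch.

```elixir
# Hypothetical helper: leading batch axes of a tensor, e.g. [0] for a
# {2, 2, 2} tensor, [] for an unbatched {2, 2} matrix.
defp batch_axes(t) do
  rank = Nx.rank(t)
  if rank > 2, do: Enum.to_list(0..(rank - 3)), else: []
end

# Before (wrong for batched inputs): Nx.dot(a, b) contracts the last
# axis of `a` with the first axis of `b`, mixing batch entries.
#
# After: contract the last axis of `a` with the second-to-last axis of
# `b`, treating the leading axes as batch dimensions.
ba = batch_axes(a)
Nx.dot(a, [-1], ba, b, [-2], ba)
```

With `a` and `b` both shaped `{2, 2, 2}`, the batched call yields a `{2, 2, 2}` result (one matrix product per batch entry) instead of the rank-blowup that triggers the broadcast error above.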