Nx.LinAlg.eigh gradient fails for two distinct input classes on current main. Both are regular-shape tensors (no Nx.vectorize involved).
A. 2D input — any dtype:
x = Nx.tensor([[4.0, 2.0], [2.0, 5.0]], type: :f32)
Nx.Defn.grad(x, fn a -> {s, _} = Nx.LinAlg.eigh(a); Nx.sum(s) end)
# ** (ArgumentError) cannot reshape, current shape {} is not compatible with new shape {1, 2, 2}
B. f64 input (any rank):
x = Nx.tensor([[[4.0, 2.0], [2.0, 5.0]]], type: :f64) # {1, 2, 2} f64
Nx.Defn.grad(x, fn a -> {s, _} = Nx.LinAlg.eigh(a); Nx.sum(s) end)
# ** (ArgumentError) unexpected size for tensor data, expected 32 bits got: 64 bits
Only working case: 3D batched :f32:
Nx.tensor([[[4.0, 2.0], [2.0, 5.0]]], type: :f32) |> Nx.Defn.grad(fn a -> {s, _} = Nx.LinAlg.eigh(a); Nx.sum(s) end) # ✓
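Until the grad rule is fixed, a possible stopgap is to route every input through that one green path: add a leading batch axis to 2D inputs and downcast f64, then undo both on the result. This is only a sketch under the assumption that the 3D :f32 path stays working; grad_via_batched_f32 is a hypothetical helper name, and the downcast costs f64 callers precision:

grad_via_batched_f32 = fn x ->
  batched =
    x
    |> Nx.as_type(:f32)  # sidestep failure class B (f64 bit-width mismatch)
    |> then(fn t ->
      # sidestep failure class A (2D input) by adding a batch dim
      if Nx.rank(t) == 2, do: Nx.new_axis(t, 0), else: t
    end)

  g = Nx.Defn.grad(batched, fn a -> {s, _} = Nx.LinAlg.eigh(a); Nx.sum(s) end)

  # drop the synthetic batch axis for originally-2D inputs
  if Nx.rank(x) == 2, do: Nx.squeeze(g, axes: [0]), else: g
end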
Found by a batched-grad fuzz probe while auditing linalg gradients after the LU/SVD batched-grad bugs. A parameterized (rank, dtype) coverage property fails on every input combination except the one shown as working.
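For reference, a minimal sketch of what such a (rank, dtype) coverage probe boils down to, assuming the same sum-of-eigenvalues grad target; the loop, labels, and output format here are illustrative, not the actual fuzz harness:

base = [[4.0, 2.0], [2.0, 5.0]]
grad_target = fn a -> {s, _} = Nx.LinAlg.eigh(a); Nx.sum(s) end

for {label, data} <- [{"2D", base}, {"3D batched", [base]}], type <- [:f32, :f64] do
  x = Nx.tensor(data, type: type)

  result =
    try do
      _ = Nx.Defn.grad(x, grad_target)
      :ok
    rescue
      e in ArgumentError -> {:error, Exception.message(e)}
    end

  # only {"3D batched", :f32, :ok} comes back clean on current main
  IO.inspect({label, type, result})
end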