
Nx.LinAlg.eigh gradient broken for 2D inputs and for all f64 inputs #1740

@blasphemetheus

Description


Nx.LinAlg.eigh gradient fails for two distinct input classes on current main. Both are regular-shape tensors (no Nx.vectorize involved).

A. 2D input — any dtype:

x = Nx.tensor([[4.0, 2.0], [2.0, 5.0]], type: :f32)
Nx.Defn.grad(x, fn a -> {s, _} = Nx.LinAlg.eigh(a); Nx.sum(s) end)
# ** (ArgumentError) cannot reshape, current shape {} is not compatible with new shape {1, 2, 2}

B. f64 input (any rank):

x = Nx.tensor([[[4.0, 2.0], [2.0, 5.0]]], type: :f64)   # {1, 2, 2} f64
Nx.Defn.grad(x, fn a -> {s, _} = Nx.LinAlg.eigh(a); Nx.sum(s) end)
# ** (ArgumentError) unexpected size for tensor data, expected 32 bits got: 64 bits

Only working case: 3D batched :f32:

Nx.tensor([[[4.0, 2.0], [2.0, 5.0]]], type: :f32) |> Nx.Defn.grad(fn a -> ... end)   # ✓
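Until this is fixed, a possible workaround (a sketch, relying only on the one combination reported to work above: batch the matrix into a 3D tensor and cast to :f32) would be:

```elixir
# Hypothetical workaround: route a 2D and/or :f64 input through the
# only reported-working case (3D batched :f32), then undo the batching.
x = Nx.tensor([[4.0, 2.0], [2.0, 5.0]], type: :f64)

grad =
  x
  |> Nx.as_type(:f32)                # :f64 grad is broken, so cast down
  |> Nx.new_axis(0)                  # {2, 2} -> {1, 2, 2} batched form
  |> Nx.Defn.grad(fn a ->
    {s, _q} = Nx.LinAlg.eigh(a)
    Nx.sum(s)
  end)
  |> Nx.squeeze(axes: [0])           # back to {2, 2}
```

Note the result comes back as :f32; callers needing :f64 would have to cast back with `Nx.as_type/2` and accept the precision loss.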

Found by a batched-grad fuzz probe while auditing linalg gradients after the LU/SVD batched-grad bugs. A parameterized (rank, dtype) coverage property fails on every input combination except the one shown as working.
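For reference, the coverage property mentioned above is roughly of this shape (an illustrative ExUnit sketch, not the actual probe; module and test names are made up). Post-fix, all four (rank, dtype) combinations should pass; on current main only rank 3 / :f32 does:

```elixir
defmodule EighGradCoverageTest do
  use ExUnit.Case, async: true

  # Symmetric base matrix shared by all cases.
  @base [[4.0, 2.0], [2.0, 5.0]]

  # One test per (rank, dtype) combination, generated at compile time.
  for rank <- [2, 3], type <- [:f32, :f64] do
    test "eigh eigenvalue grad, rank #{rank}, #{type}" do
      data = if unquote(rank) == 3, do: [@base], else: @base
      x = Nx.tensor(data, type: unquote(type))

      grad =
        Nx.Defn.grad(x, fn a ->
          {s, _q} = Nx.LinAlg.eigh(a)
          Nx.sum(s)
        end)

      # The gradient must exist and match the input's shape and dtype.
      assert Nx.shape(grad) == Nx.shape(x)
      assert Nx.type(grad) == Nx.type(x)
    end
  end
end
```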
