Nx.LinAlg.triangular_solve gradient broken for batched inputs #1741

@blasphemetheus

The gradient of Nx.LinAlg.triangular_solve raises on batched (rank 3 or higher) inputs. The failure is a shape mismatch in the Nx.dot call inside the custom grad formula. This is related to but distinct from #1729: the reproduction here uses a plain batched tensor with no Nx.vectorize wrapping.

a = Nx.tensor([
  [[4.0, 0.0], [2.0, 5.0]],
  [[9.0, 0.0], [3.0, 10.0]]
])  # shape {2, 2, 2}, lower-triangular
b = Nx.broadcast(1.0, {2, 2})

Nx.Defn.grad(a, fn x ->
  Nx.sum(Nx.LinAlg.triangular_solve(x, b))
end)
# ** (ArgumentError) dot/zip expects shapes to be compatible,
#    dimension 2 of left-side (2) does not equal dimension 0 of right-side (2)

This is the same class of bug as the LU/SVD batched-grad issues: the custom_grad formulas were written assuming 2D inputs and have not been made batch-aware. It is likely fixable with the same batch_axes/1 helper pattern that PR #1731 applies to qr_grad and cholesky_grad.
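For illustration, here is a minimal sketch of what a batch-aware product could look like, using the batch-axes arguments of Nx.dot/6. The module name BatchedDotSketch and both function bodies are hypothetical; batch_axes/1 here only borrows the name of the helper mentioned above and is a guess at the pattern, not the code from PR #1731. It assumes Nx is already available in the session and that both operands share the same leading batch dimensions.

```elixir
# Hypothetical sketch, not Nx's actual grad implementation.
defmodule BatchedDotSketch do
  # Treat every axis except the last two as a batch axis.
  # For a rank-2 tensor the range 0..-1//1 is empty, so this returns [].
  def batch_axes(tensor) do
    Enum.to_list(0..(Nx.rank(tensor) - 3)//1)
  end

  # Matrix product over the last two axes, batched over the leading ones.
  # Contracts the last axis of x with the second-to-last axis of y.
  def batched_matmul(x, y) do
    rank_x = Nx.rank(x)
    rank_y = Nx.rank(y)

    Nx.dot(x, [rank_x - 1], batch_axes(x), y, [rank_y - 2], batch_axes(y))
  end
end
```

With 2D inputs the batch lists are empty and the call reduces to an ordinary matrix product, so a formula written this way should keep working for the unbatched case while also accepting {batch, n, n} operands.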

Found by a batched-grad fuzz probe while auditing linalg gradients.
