
Vectorized gradient fails for apply_vectorized ops (Cholesky, triangular_solve, cond) #1729

@blasphemetheus


Discovered while implementing the boundary-wrapper approach for #1533: computing the gradient of three ops fails when their inputs are vectorized.

```elixir
# cholesky
x =
  Nx.tensor([
    [[4.0, 2.0], [2.0, 5.0]],
    [[9.0, 3.0], [3.0, 5.0]],
    [[16.0, 4.0], [4.0, 8.0]]
  ])
  |> Nx.vectorize(:batch)

Nx.Defn.grad(x, fn x -> Nx.sum(Nx.LinAlg.cholesky(x)) end)
```

```elixir
# triangular_solve with a captured non-vectorized matrix
a = Nx.tensor([[1.0, 0.0], [2.0, 3.0]])
b = Nx.tensor([[4.0, 5.0], [2.0, 3.0], [1.0, 1.0]]) |> Nx.vectorize(:batch)

Nx.Defn.grad(b, fn b -> Nx.sum(Nx.LinAlg.triangular_solve(a, b)) end)
```

```elixir
# cond with vectorized input
x = Nx.tensor([[2.0, 3.0], [-5.0, -6.0], [1.0, 1.0]]) |> Nx.vectorize(:batch)
Nx.Defn.grad(x, fn x -> Nx.sum(Nx.select(Nx.greater(x, 0), x, -x)) end)
```

Root cause: all three are `apply_vectorized` ops. They devectorize their inputs internally and re-vectorize the result, and that internal re-vectorization is incompatible with a boundary-wrapper grad that runs the recursion in fully devectorized space.

@polvalente noted on prior #1533 work that the boundary wrapper is allowed to forbid internal re-vectorization, so the fix belongs in the ops themselves: either rewrite each op to operate without internal re-vectorization, or document the limitation as a permanent constraint of the boundary approach.
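To make "fully devectorized space" concrete, here is a minimal user-side sketch for the `triangular_solve` case. It is not the #1533 boundary wrapper itself; it only mimics the idea by hand. It assumes Nx `~> 0.9` can be fetched with `Mix.install` (a version choice not stated in this issue). The vectorized `:batch` axis is turned into an ordinary leading axis with `Nx.devectorize/2` before differentiation, the op only ever sees a plain tensor, and the result is re-vectorized once with `Nx.vectorize/2` outside the differentiated function.

```elixir
# Assumption: Nx ~> 0.9 is available via Mix.install (not specified in the issue).
Mix.install([{:nx, "~> 0.9"}])

# Captured, non-vectorized lower-triangular matrix (same values as the issue's example).
a = Nx.tensor([[1.0, 0.0], [2.0, 3.0]])

# Three right-hand sides of shape {2}, vectorized along :batch.
b = Nx.tensor([[4.0, 5.0], [2.0, 3.0], [1.0, 1.0]]) |> Nx.vectorize(:batch)

# Boundary in: :batch becomes a plain leading axis, so b_flat has shape {3, 2}.
b_flat = Nx.devectorize(b, keep_names: false)

# Differentiate with no vectorized tensors involved. Solving A * X = B^T column by
# column is the same as solving A * x_i = b_i for each row b_i, and each row only
# affects its own solution, so the gradient of the total sum w.r.t. row i equals the
# gradient of that row's own sum.
grad_flat =
  Nx.Defn.grad(b_flat, fn b_flat ->
    a
    |> Nx.LinAlg.triangular_solve(Nx.transpose(b_flat))
    |> Nx.sum()
  end)

# Boundary out: re-vectorize once, outside the differentiated function.
grad = Nx.vectorize(grad_flat, :batch)
IO.inspect(grad)
```

Analytically, the gradient of `sum(A⁻¹ b)` with respect to `b` is `A⁻ᵀ · 1`, which for this `a` is `[1/3, 1/3]` for every batch entry, so the sketch can be checked against that closed form.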

Tests for all three exist in `nx/test/nx/defn/grad_test.exs` in blasphemetheus#8; they are currently skipped, with comments pointing to this root cause.
