Discovered while implementing the boundary-wrapper approach for #1533. Three ops fail when their inputs are vectorized:
```elixir
# cholesky
x =
  Nx.tensor([
    [[4.0, 2.0], [2.0, 5.0]],
    [[9.0, 3.0], [3.0, 5.0]],
    [[16.0, 4.0], [4.0, 8.0]]
  ])
  |> Nx.vectorize(:batch)

Nx.Defn.grad(x, fn x -> Nx.sum(Nx.LinAlg.cholesky(x)) end)

# triangular_solve with a captured non-vectorized matrix
a = Nx.tensor([[1.0, 0.0], [2.0, 3.0]])
b = Nx.tensor([[4.0, 5.0], [2.0, 3.0], [1.0, 1.0]]) |> Nx.vectorize(:batch)

Nx.Defn.grad(b, fn b -> Nx.sum(Nx.LinAlg.triangular_solve(a, b)) end)

# cond with vectorized input
x = Nx.tensor([[2.0, 3.0], [-5.0, -6.0], [1.0, 1.0]]) |> Nx.vectorize(:batch)

# Nx.negate/1 instead of unary minus: outside defn, `-x` is Kernel.-/1 and raises on a tensor
Nx.Defn.grad(x, fn x -> Nx.sum(Nx.select(Nx.greater(x, 0), x, Nx.negate(x))) end)
```
Root cause: all three are `apply_vectorized` ops, meaning they devectorize their inputs internally and re-vectorize the result. That internal re-vectorization is incompatible with a boundary-wrapper grad, which runs the recursion in fully devectorized space. @polvalente noted on prior #1533 work that the boundary wrapper is allowed to forbid internal re-vectorization, so the fix belongs in the ops themselves: rewrite each to operate without internal re-vectorization. The alternative is to document the limitation as a permanent constraint of the boundary approach.
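For illustration, the devectorize / re-vectorize round trip described above can be sketched with public Nx functions. This is not the internal helper itself, only an assumed approximation of its shape: the doubling step is a placeholder for the real op, and the Nx version in `Mix.install` is an assumption.

```elixir
# Assumption: any Nx release with vectorization support; the version is a guess.
Mix.install([{:nx, "~> 0.7"}])

x = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) |> Nx.vectorize(:batch)

# Remember which axes were vectorized so they can be restored afterwards.
vectorized_axes = x.vectorized_axes

# Turn :batch back into an ordinary leading axis (shape {2, 2}).
devec = Nx.devectorize(x, keep_names: false)

# Placeholder for the real op, applied to the plain tensor.
result = Nx.multiply(devec, 2)

# Re-vectorize the result. This is the step a boundary wrapper that stays
# in devectorized space cannot tolerate inside the ops it differentiates.
revec = Nx.vectorize(result, vectorized_axes)

IO.inspect(revec.vectorized_axes)
```

An op rewritten for the boundary approach would presumably stop after the placeholder step and leave re-vectorization to the wrapper.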
Tests for all three exist in `nx/test/nx/defn/grad_test.exs` (in blasphemetheus#8), currently skipped with comments pointing at this root cause.