Reproducer:
x = Nx.tensor([[1], [2], [3]]) |> Nx.vectorize(x: 3)
Nx.Defn.grad(x, fn t ->
  devec = Nx.devectorize(t, keep_names: true)
  new_axis = Nx.reshape(devec, {1, 3, 1}, names: [:x, nil, nil])
  Nx.vectorize(new_axis, x: 1)
end)
The input has vectorized axis x: 3, but the output has vectorized axis x: 1. The name is the same but the sizes differ, so these are conceptually different axes that happen to share a name. Grad's boundary devectorization gets confused about which one to seed against.
Discovered while implementing the boundary-wrapper approach for #1533. This is distinct from the heterogeneous-axes case: there is only one input here, and the conflict is between the input and output shapes on a same-named axis. The skipped test is in nx/test/nx/defn/grad_test.exs under "edge case where the same name changes meaning" (in a fork draft PR, blasphemetheus#8).