`Nx.Defn.grad` through an `Nx.Defn.while` loop miscomputes the gradient w.r.t. any variable whose presence makes the loop body's Jacobian ∂body/∂acc depend on that variable. The forward pass is correct; only the backward pass is wrong.
## Minimal reproducer

```elixir
defmodule M do
  import Nx.Defn

  defn square_via_while(x) do
    {acc, _, _} =
      while {acc = Nx.tensor(1.0, type: :f32), i = 0, x = x}, Nx.less(i, 2) do
        {Nx.multiply(acc, x), i + 1, x}
      end

    acc
  end
end
```

```elixir
x = Nx.tensor(0.5, type: :f32)

Nx.to_number(M.square_via_while(x))
#=> 0.25, forward pass correct (x²)

Nx.Defn.grad(x, &M.square_via_while/1) |> Nx.to_number()
#=> 1.25, wrong; expected 1.0 (= 2·x at x = 0.5)
```
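As a sanity check independent of Nx, a plain-Python model of the same loop (an illustration only, not Nx code) confirms both the forward value and the expected derivative via a central finite difference:

```python
def square_via_while(x, n=2):
    # Plain-Python model of the loop: acc starts at 1.0 and is
    # multiplied by x on each of the n iterations.
    acc = 1.0
    for _ in range(n):
        acc = acc * x
    return acc

assert square_via_while(0.5) == 0.25  # forward pass, matches Nx

# Central finite difference approximates the true gradient.
h = 1e-6
fd = (square_via_while(0.5 + h) - square_via_while(0.5 - h)) / (2 * h)
print(fd)  # ~1.0 (= 2*x at x = 0.5), not the 1.25 Nx reports
```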
## Characterization
For the loop above, Nx returns Σₖ x^(2k) for k=0..n-1 instead of the correct n·x^(n-1). The bug fires whenever ∂body/∂acc contains the differentiated variable:
| Loop body | ∂body/∂acc | Expected (x=0.5, n=3) | Nx returns | OK? |
|---|---|---|---|---|
| `acc + x` | 1 | 3.0 | 3.0 | ✓ |
| `acc * 2` | 2 | 0.0 | 0.0 | ✓ |
| `sin(acc)` | cos(acc) | 0.0 | 0.0 | ✓ |
| `acc * acc` | 2·acc | 0.0 | 0.0 | ✓ |
| `acc + x²` | 1 | 3.0 | 3.0 | ✓ |
| `acc * x` | x | 0.75 | 1.3125 | ✗ |
| `acc / x` | 1/x | -48.0 | -84.0 | ✗ |
| `sin(acc * x)` | cos(acc·x)·x | 0.8920 (finite diff.) | 1.0819 | ✗ |
| `acc + acc*x` | 1 + x | 6.75 | 8.3125 | ✗ |
Per-variable: in a body like `acc * x + y`, ∂/∂y is correct (y appears only additively) while ∂/∂x is wrong.
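Both columns of the `acc * x` row follow the closed forms above; this small script (a consistency check on the characterization, not Nx code) reproduces the tabulated 0.75 vs 1.3125:

```python
def correct_grad(x, n):
    # true value: d/dx of x^n
    return n * x ** (n - 1)

def buggy_grad(x, n):
    # the observed wrong value: sum of x^(2k) for k = 0..n-1
    return sum(x ** (2 * k) for k in range(n))

print(correct_grad(0.5, 3))  # 0.75
print(buggy_grad(0.5, 3))    # 1.3125 (= 1 + 0.25 + 0.0625)
print(buggy_grad(0.5, 2))    # 1.25, the minimal reproducer's value
```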
## Catastrophic near zero
| x | Correct grad of x³ (3x²) | Nx returns | Ratio |
|---|---|---|---|
| 0.001 | 3×10⁻⁶ | 1.0 | 333,334× |
| 0.01 | 3×10⁻⁴ | 1.0001 | 3,334× |
| 0.1 | 0.030 | 1.010 | 34× |
| 0.5 | 0.75 | 1.3125 | 1.75× |
| 1.0 | 3.0 | 3.0 | 1× (coincidental) |
| 2.0 | 12.0 | 21.0 | 1.75× |
As x → 0, the correct gradient tends to 0 while the buggy value tends to 1, so any gradient-based optimizer starting near the origin steps in a spurious direction.
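The blow-up follows directly from the closed forms: the buggy sum Σₖ x^(2k) has constant term 1, so it tends to 1 while the true gradient 3x² vanishes. A quick sweep (hypothesis check, not Nx code) reproduces the ratio column:

```python
def ratio(x, n=3):
    correct = n * x ** (n - 1)                   # true gradient of x^n
    buggy = sum(x ** (2 * k) for k in range(n))  # observed wrong value
    return buggy / correct

# x = 0.5 and x = 2.0 both give 1.75; the ratio diverges as x -> 0.
for x in (0.001, 0.01, 0.1, 0.5, 2.0):
    print(x, ratio(x))
```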
## Sign flips
With `acc * (x-2)`, the buggy result is a sum of even powers of (x-2), hence always positive even where the correct gradient is negative; with `acc * (-x)` that sum is negated, so the sign agrees but the magnitude is far off:

| Loop body | x | n | Correct | Nx returns | Note |
|---|---|---|---|---|---|
| `acc * (x-2)` | 0.5 | 2 | -3.0 | +3.25 | sign flip |
| `acc * (x-2)` | 1.5 | 4 | -0.5 | +1.33 | sign flip |
| `acc * (-x)` | 0.3 | 3 | -0.27 | -1.10 | ~4× too large |
| `acc * (-x)` | 0.7 | 3 | -1.47 | -1.73 | ~1.2× too large |
An optimizer trusting these would move away from the optimum.
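The `acc * (x-2)` rows fit the same even-power pattern with x replaced by (x-2). This consistency check (not Nx code) reproduces both columns of those rows exactly:

```python
def correct_grad(x, n):
    # true value: d/dx of (x-2)^n
    return n * (x - 2) ** (n - 1)

def buggy_grad(x, n):
    # sum of even powers of (x-2): always >= 1, hence always positive
    return sum((x - 2) ** (2 * k) for k in range(n))

print(correct_grad(0.5, 2), buggy_grad(0.5, 2))  # -3.0 3.25
print(correct_grad(1.5, 4), buggy_grad(1.5, 4))  # -0.5 1.328125
```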
## Additional properties
- Type-independent: identical wrong values for `:f32`, `:f64`, `:bf16`; this is not a backend issue.
- Syntax-independent: both `while {state}, cond do` and `while {state}, i <- range do` exhibit it.
- Compounds at second order: d²/dx² of x³ through the loop gives 1.5 at x=0.5 (correct: 3.0) and 36 at x=2 (correct: 12).
- Scales to tensor x: each element of a vector-valued x gets its own wrong gradient, and under `Nx.vectorize` each batch slice likewise gets its own individually wrong value.
- The forward pass is always correct; the bug is strictly in reverse-mode accumulation.
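The second-order numbers above are consistent with differentiating the buggy first-order formula itself: for n = 3 the wrong gradient is 1 + x² + x⁴, whose derivative 2x + 4x³ gives 1.5 at x = 0.5 and 36 at x = 2 (a consistency check on the report's numbers, not Nx code):

```python
def buggy_grad(x, n=3):
    # the observed wrong first derivative of x^3: 1 + x^2 + x^4
    return sum(x ** (2 * k) for k in range(n))

# Finite-differencing the buggy formula (analytically 2x + 4x^3)
# reproduces the observed second-order values.
h = 1e-6
def buggy_second(x):
    return (buggy_grad(x + h) - buggy_grad(x - h)) / (2 * h)

print(buggy_second(0.5))  # ~1.5 (correct d²/dx² x³ at x=0.5 is 3.0)
print(buggy_second(2.0))  # ~36  (correct is 12.0)
```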
## Use cases affected

Any `defn` that uses a `while` loop with multiplicative accumulation and backpropagates through it:
- Power series / Taylor expansions
- Iterative refinement (Newton's method, power iteration)
- Matrix exponentials via Padé iterations
- Custom RNN / Markov-chain dynamics
- Learning-rate schedules multiplied against state
- Any user-written iterative algorithm where each step depends multiplicatively on the input
Users of `Nx.LinAlg` ops are not affected: QR / Cholesky / SVD / LU all have `custom_grad` formulas that bypass the internal `while`. The bug manifests for user-written `defn` code that uses `while` directly.
## Where the fix likely lives

In `nx/lib/nx/defn/grad.ex`'s `while` handling: the reverse-mode accumulation across iterations needs to compose the Jacobian properly at each step, including the x-dependence introduced by the body.
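One mechanism consistent with every wrong value above (a hypothesis inferred from the numbers, not the actual `grad.ex` code): the reverse sweep appears to pair cotangents with forward-order intermediate values instead of reverse-order ones. A plain-Python model of the two orderings reproduces both the correct gradient and the observed 1.25:

```python
def loop_grad(x, n, reverse=True):
    # Forward pass for body acc_new = acc * x, recording acc before each step.
    accs, acc = [], 1.0
    for _ in range(n):
        accs.append(acc)
        acc = acc * x

    # Reverse-mode sweep: g_acc is the cotangent of acc, g_x accumulates d/dx.
    order = reversed(range(n)) if reverse else range(n)
    g_acc, g_x = 1.0, 0.0
    for k in order:
        g_x += g_acc * accs[k]  # d(acc*x)/dx = acc at step k
        g_acc *= x              # d(acc*x)/dacc = x
    return g_x

print(loop_grad(0.5, 2, reverse=True))   # 1.0, the correct 2x
print(loop_grad(0.5, 2, reverse=False))  # 1.25, the value Nx returns
```

Running the mis-ordered sweep with n = 3 gives 1.3125, matching the characterization table; whether this is what actually happens inside `grad.ex` remains a guess.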
Found by a fuzz sweep over `Nx.Defn.while` gradient patterns.