
Nx.Defn.while gradient incorrect when loop body's Jacobian w.r.t. accumulator depends on differentiated variable #1747

@blasphemetheus

Description


Nx.Defn.grad through an Nx.Defn.while loop miscomputes the gradient w.r.t. any variable that the loop body's Jacobian ∂body/∂acc depends on. The forward pass is correct; only the backward pass is wrong.

Minimal reproducer

defmodule M do
  import Nx.Defn

  # Multiplies acc by x twice, so the loop computes x².
  defn square_via_while(x) do
    {acc, _, _} =
      while {acc = Nx.tensor(1.0, type: :f32), i = 0, x = x}, Nx.less(i, 2) do
        {Nx.multiply(acc, x), i + 1, x}
      end

    acc
  end
end

x = Nx.tensor(0.5, type: :f32)
Nx.to_number(M.square_via_while(x))               # 0.25  — forward correct (x²)
Nx.Defn.grad(x, &M.square_via_while/1) |> Nx.to_number()
# 1.25  — wrong; expected 1.0 (= 2·x at x=0.5)
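
A central finite difference (my check, not from the report; h is an illustrative step size) agrees with the analytic 2·x and disagrees with grad:

h = 1.0e-3
a = Nx.to_number(M.square_via_while(Nx.tensor(0.5 + h)))
b = Nx.to_number(M.square_via_while(Nx.tensor(0.5 - h)))
(a - b) / (2 * h)
# ≈ 1.0 — matches 2·x at x = 0.5, not the 1.25 returned by grad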

Characterization

For the loop above, Nx returns Σₖ x^(2k) for k=0..n-1 instead of the correct n·x^(n-1). The bug fires whenever ∂body/∂acc contains the differentiated variable:

| Loop body    | ∂body/∂acc   | Expected (x=0.5, n=3) | Nx returns | OK? |
|--------------|--------------|-----------------------|------------|-----|
| acc + x      | 1            | 3.0                   | 3.0        | ✓   |
| acc * 2      | 2            | 0.0                   | 0.0        | ✓   |
| sin(acc)     | cos(acc)     | 0.0                   | 0.0        | ✓   |
| acc * acc    | 2·acc        | 0.0                   | 0.0        | ✓   |
| acc + x²     | 1            | 3.0                   | 3.0        | ✓   |
| acc * x      | x            | 0.75                  | 1.3125     | ✗   |
| acc / x      | 1/x          | -48.0                 | -84.0      | ✗   |
| sin(acc * x) | cos(acc·x)·x | 0.8920 (finite diff.) | 1.0819     | ✗   |
| acc + acc*x  | 1 + x        | 6.75                  | 8.3125     | ✗   |
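
A quick arithmetic check (illustrative, just evaluating the two closed forms above) reproduces the acc * x row:

# buggy formula Σₖ x^(2k) vs correct n·x^(n−1) at x = 0.5, n = 3
x = 0.5
n = 3
buggy = Enum.sum(for k <- 0..(n - 1), do: :math.pow(x, 2 * k))
correct = n * :math.pow(x, n - 1)
{buggy, correct}
# => {1.3125, 0.75} — exactly the table's "Nx returns" and "Expected"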

Per-variable: in a body like acc * x + y, ∂/∂y is correct (y enters additively) while ∂/∂x is wrong.
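
A sketch of that per-variable split (hypothetical module and function names, mirroring the report's claim):

defmodule PerVar do
  import Nx.Defn

  # Body acc * x + y: x is in multiplicative position, y in additive.
  defn mixed(x, y) do
    {acc, _, _, _} =
      while {acc = Nx.tensor(1.0), i = 0, x = x, y = y}, Nx.less(i, 2) do
        {Nx.multiply(acc, x) |> Nx.add(y), i + 1, x, y}
      end

    acc
  end
end

x = Nx.tensor(0.5)
y = Nx.tensor(0.1)
Nx.Defn.grad(y, &PerVar.mixed(x, &1))  # correct per the report
Nx.Defn.grad(x, &PerVar.mixed(&1, y))  # wrong per the report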

Catastrophic near zero

For the acc * x body with n = 3 (correct gradient 3x²):

| x     | Correct grad | Nx returns | Ratio             |
|-------|--------------|------------|-------------------|
| 0.001 | 3×10⁻⁶       | 1.0        | 333,334×          |
| 0.01  | 3×10⁻⁴       | 1.0001     | 3,334×            |
| 0.1   | 0.030        | 1.010      | 34×               |
| 0.5   | 0.75         | 1.3125     | 1.75×             |
| 1.0   | 3.0          | 3.0        | 1× (coincidental) |
| 2.0   | 12.0         | 21.0       | 1.75×             |

As x → 0, the correct gradient tends to 0 while the buggy one tends to 1. Any gradient-based optimizer starting near the origin steps in a spurious direction.

Sign flips

With acc * (x-2) or acc * (-x), the buggy result comes out with the wrong sign: for acc * (x-2) the buggy sum of even powers is always positive while the correct gradient is negative, and for acc * (-x) the signs are reversed:

| Loop body   | x   | n | Correct | Nx returns | Note      |
|-------------|-----|---|---------|------------|-----------|
| acc * (x-2) | 0.5 | 2 | -3.0    | +3.25      | sign flip |
| acc * (x-2) | 1.5 | 4 | -0.5    | +1.33      | sign flip |
| acc * (-x)  | 0.3 | 3 | +0.27   | -1.10      | sign flip |
| acc * (-x)  | 0.7 | 3 | +1.47   | -1.73      | sign flip |

An optimizer trusting these would move away from the optimum.

Additional properties

  • Type-independent: identical wrong values for :f32, :f64, :bf16 — not a backend issue.
  • Syntax-independent: both the condition form (while {state}, condition do) and the generator form (while {state}, i <- range do) exhibit it.
  • Compounds at second order: d²/dx² of x³ through the loop gives 1.5 at x=0.5 (correct: 3.0) and 36 at x=2 (correct: 12); see the sketch after this list.
  • Scales to tensor x: the gradient is wrong element-wise at each position of a vector-valued x.
  • Propagates through Nx.vectorize: each batch slice gets its own individually wrong value.
  • Forward pass is always correct — the bug is strictly in reverse-mode accumulation.
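
A hypothetical second-order check (my sketch; assumes nested grad, which Nx supports for higher-order derivatives):

defmodule SecondOrder do
  import Nx.Defn

  # Same multiplicative while as the reproducer, but three iterations: x³.
  defn cube_via_while(x) do
    {acc, _, _} =
      while {acc = Nx.tensor(1.0, type: :f32), i = 0, x = x}, Nx.less(i, 3) do
        {Nx.multiply(acc, x), i + 1, x}
      end

    acc
  end

  defn d2(x) do
    grad(x, fn x -> grad(x, &cube_via_while/1) end)
  end
end

SecondOrder.d2(Nx.tensor(0.5))  # reportedly 1.5; d²/dx² x³ = 6x = 3.0
SecondOrder.d2(Nx.tensor(2.0))  # reportedly 36; correct is 12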

Use cases affected

Any defn that uses a while loop with multiplicative accumulation and backprops through it:

  • Power series / Taylor expansions
  • Iterative refinement (Newton's method, power iteration)
  • Matrix exponentials via Padé iterations
  • Custom RNN / Markov-chain dynamics
  • Learning-rate schedules multiplied against state
  • Any user-written iterative algorithm where each step depends multiplicatively on the input

Users of Nx.LinAlg ops are not affected: QR / Cholesky / SVD / LU all have custom_grad formulas that bypass the internal while. The bug manifests in user-written defn code that uses while directly.
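
Until this is fixed, one stopgap for static trip counts (my sketch, not from the report) is to unroll at compile time with Enum.reduce, which sidesteps the while gradient rule entirely:

defmodule Unrolled do
  import Nx.Defn

  # Enum.reduce over a literal range unrolls while the defn is traced,
  # so grad differentiates plain multiplications rather than a while.
  defn square_unrolled(x) do
    Enum.reduce(1..2, Nx.tensor(1.0, type: :f32), fn _i, acc ->
      Nx.multiply(acc, x)
    end)
  end
end

Nx.Defn.grad(Nx.tensor(0.5), &Unrolled.square_unrolled/1)
# 1.0 expected, since no while appears in the expression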

Where the fix likely lives

In nx/lib/nx/defn/grad.ex's while handling — reverse-mode accumulation across iterations needs to properly compose the Jacobian at each step, including the x-dependence introduced by the body.
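
For reference, the standard reverse-mode recursion the while rule needs (textbook adjoint math, not taken from grad.ex): with accₖ = f(acc₍ₖ₋₁₎, x) for k = 1..n,

$$\bar a_n = 1, \qquad \bar a_{k-1} = \bar a_k \, \frac{\partial f}{\partial \mathrm{acc}}\Big|_{(\mathrm{acc}_{k-1},\, x)}, \qquad \bar x \mathrel{+}= \bar a_k \, \frac{\partial f}{\partial x}\Big|_{(\mathrm{acc}_{k-1},\, x)}$$

For f(acc, x) = acc·x this yields x̄ = Σₖ₌₁ⁿ x^(n−k)·x^(k−1) = n·x^(n−1), and it requires the per-iteration forward values acc₍ₖ₋₁₎. Notably, accumulating the Jacobian product from the start of the loop rather than from the end gives Σₖ₌₁ⁿ x^(k−1)·x^(k−1) = Σₖ₌₀ⁿ⁻¹ x^(2k) — exactly the wrong values observed above — which may hint at where the composition goes astray.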

Found by a fuzz sweep over Nx.Defn.while gradient patterns.
