
Commit 683fc91

fix: use exp2 for norm_qk_scale to correctly exponentiate log-space parameter
Bug: In MultiHeadAttention, when normalize_qk=True, the learned scale parameter 'norm_qk_scale' was initialized with nn.initializers.constant(jnp.log2(seq_len_kv**2 - seq_len_kv + SAFETY_EPSILON)), i.e. a log2-space value, but it was then passed directly to _dot_product_attention as a linear scale factor. Consequences:

- The initial scale is ~2 * log2(seq_len_kv) instead of the intended seq_len_kv**2.
- The parameter semantics are broken: gradient updates act on the wrong manifold.
- For seq_len_kv = 64, the actual initial scale is ~12, not 4096.

Fix:

1. Change the initializer to nn.initializers.zeros_init() so the parameter represents a log2-space scale (exp2(0) = 1 is a sensible default).
2. Add scale = jnp.exp2(scale) after the self.param() call to convert the parameter to linear space before use.

This matches the intent of storing the scale in log space for unconstrained optimization while ensuring the linear scale is always positive.
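A minimal numeric sketch of the mismatch described above (not part of the repository; the SAFETY_EPSILON value and seq_len_kv = 64 are assumed here only for illustration):

```python
import jax.numpy as jnp

SAFETY_EPSILON = 1e-6  # assumed small constant; the repository defines its own
seq_len_kv = 64

# Old behavior: the initializer stored a log2-space value...
old_init = jnp.log2(seq_len_kv**2 - seq_len_kv + SAFETY_EPSILON)
print(old_init)  # ~11.98, and this was used directly as the linear scale, not ~4096

# New behavior: the parameter starts at 0 in log2 space and is exponentiated
# before use, so the linear scale starts at exp2(0) = 1 and stays positive
# for any real-valued parameter the optimizer produces.
log2_scale = 0.0
print(jnp.exp2(log2_scale))  # 1.0
```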
1 parent ac74c70 commit 683fc91

1 file changed

Lines changed: 3 additions & 4 deletions

hackable_diffusion/lib/architecture/attention.py

```diff
@@ -300,11 +300,10 @@ def __call__(
     if self.normalize_qk:
       scale = self.param(
           "norm_qk_scale",
-          nn.initializers.constant(
-              jnp.log2(seq_len_kv**2 - seq_len_kv + SAFETY_EPSILON)
-          ),
-          (1, 1, 1, 1),
+          nn.initializers.zeros_init(),
+          (1, 1, 1, 1),
       )
+      scale = jnp.exp2(scale)
 
       norm_q = jnp.linalg.norm(q, ord=2, axis=-1, keepdims=True)
       norm_k = jnp.linalg.norm(k, ord=2, axis=-1, keepdims=True)
```
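For context, here is a self-contained sketch of the corrected pattern. The module name LogScaledQKNorm, the normalization epsilon, and folding the scale into q are illustrative assumptions, not the repository's actual MultiHeadAttention code:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class LogScaledQKNorm(nn.Module):
  """Unit-normalizes q/k and applies a learned, always-positive scale."""

  @nn.compact
  def __call__(self, q, k):
    # The parameter lives in log2 space; zeros_init() means exp2(0) = 1 at init.
    log2_scale = self.param(
        "norm_qk_scale", nn.initializers.zeros_init(), (1, 1, 1, 1)
    )
    scale = jnp.exp2(log2_scale)  # convert to linear space; positive by construction

    norm_q = jnp.linalg.norm(q, ord=2, axis=-1, keepdims=True)
    norm_k = jnp.linalg.norm(k, ord=2, axis=-1, keepdims=True)
    q = q / (norm_q + 1e-6)
    k = k / (norm_k + 1e-6)
    # Folding the scale into q is equivalent to scaling the attention logits.
    return q * scale, k


# Example usage with dummy (batch, heads, seq, head_dim) tensors.
q = jnp.ones((1, 2, 8, 16))
k = jnp.ones((1, 2, 8, 16))
module = LogScaledQKNorm()
params = module.init(jax.random.PRNGKey(0), q, k)
q_scaled, k_normed = module.apply(params, q, k)
```

Storing the parameter in log2 space keeps optimization unconstrained while guaranteeing that the effective scale passed to the attention computation is strictly positive.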
