Commit 683fc91
fix: use exp2 for norm_qk_scale to correctly exponentiate log-space parameter
Bug: In MultiHeadAttention, when normalize_qk=True, the learned scale parameter 'norm_qk_scale' was initialized with nn.initializers.constant(jnp.log2(seq_len_kv**2 - ...)), storing a log2-space value, but was then passed directly to _dot_product_attention as a linear scale factor.
This means:
- The initial scale ≈ 2*log2(seq_len_kv) instead of seq_len_kv^2 as intended
- The parameter semantics are broken: gradient updates act on the wrong manifold
- For seq_len_kv=64, the actual initial scale is ~12, not 4096 (see the quick check below)
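
A quick numeric check of that last bullet. This is plain JAX arithmetic; `seq_len_kv` is the only name taken from the commit, the rest is illustrative:

```python
import jax.numpy as jnp

seq_len_kv = 64
intended_scale = seq_len_kv ** 2                 # 4096, the linear scale the init was meant to produce
stored_value = float(jnp.log2(seq_len_kv ** 2))  # 12.0, the log2 value actually stored in the parameter
# The buggy code passed stored_value straight through as the attention scale,
# so the effective initial scale was ~12 rather than 4096.
print(intended_scale, stored_value)  # 4096 12.0
```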
Fix:
1. Change the initializer to nn.initializers.zeros_init() so the parameter represents a log2-space scale (exp2(0) = 1 is a sensible default)
2. Add scale = jnp.exp2(scale) after the self.param() call to correctly convert to linear space before use
This matches the intent of storing the scale in log space for unconstrained optimization while ensuring the linear scale is always positive.
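
A minimal sketch of the corrected parameter handling. Only the 'norm_qk_scale' lines mirror the fix described above: `norm_qk_scale`, `normalize_qk`, `nn.initializers.zeros_init()`, and `jnp.exp2` come from the commit message, while the surrounding module body and the inlined stand-in for _dot_product_attention are assumptions for illustration:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class MultiHeadAttentionSketch(nn.Module):
    """Illustrative module; only the norm_qk_scale handling mirrors the fix."""
    normalize_qk: bool = True

    @nn.compact
    def __call__(self, q, k, v):
        if self.normalize_qk:
            # Normalize q/k so the learned scale controls logit magnitude.
            q = q / (jnp.linalg.norm(q, axis=-1, keepdims=True) + 1e-6)
            k = k / (jnp.linalg.norm(k, axis=-1, keepdims=True) + 1e-6)
            # Fix part 1: initialize the parameter at 0 in log2 space
            # (exp2(0) == 1, a sensible default linear scale).
            scale = self.param('norm_qk_scale', nn.initializers.zeros_init(), ())
            # Fix part 2: exponentiate to convert the log2-space parameter
            # to an always-positive linear scale before use.
            scale = jnp.exp2(scale)
        else:
            scale = 1.0 / jnp.sqrt(q.shape[-1])
        # Hypothetical stand-in for _dot_product_attention.
        logits = jnp.einsum('...qd,...kd->...qk', q, k) * scale
        weights = jax.nn.softmax(logits, axis=-1)
        return jnp.einsum('...qk,...kd->...qd', weights, v)
```

Keeping the parameter in log2 space leaves the optimizer unconstrained while guaranteeing exp2(scale) > 0, which is exactly the property the concluding sentence describes.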
1 file changed: 3 additions & 4 deletions