
Commit 36902d4

Fareed Sheriff and Hackable Diffusion Authors authored and committed
Refactor Hackable Diffusion for improved XLA performance and Flash Attention support.
PiperOrigin-RevId: 908804155
1 parent 8ab544c commit 36902d4

5 files changed: 79 additions & 13 deletions


hackable_diffusion/lib/architecture/attention.py

Lines changed: 8 additions & 1 deletion
@@ -142,7 +142,14 @@ def _dot_product_attention(
     The output tensor.
   """
 
-  b, _, t, _ = q.shape
+  b, _, t, head_d = q.shape
+
+  # Use jax.nn.dot_product_attention for hardware acceleration when possible.
+  if mask is None and hasattr(jax.nn, "dot_product_attention"):
+    q_scaled = q * (rescale * jnp.sqrt(head_d))
+    attn_output = jax.nn.dot_product_attention(q_scaled, k, v)
+    attn_output = attn_output.transpose(0, 2, 1, 3).reshape(b, t, -1)
+    return attn_output
 
   # Attention scores
   attn_logits = jnp.einsum("bhtd,bhsd->bhts", q, k) * rescale
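For context, a standalone sketch of why the new path pre-scales the queries: jax.nn.dot_product_attention applies its default 1/sqrt(head_d) softmax scaling internally, so multiplying q by (rescale * sqrt(head_d)) beforehand reproduces a custom rescale factor. The snippet below is illustrative only — it uses the documented (batch, seq, heads, head_dim) layout and made-up sizes, not the repo's internals — and the two results should agree to within float tolerance.

```python
import jax
import jax.numpy as jnp

b, t, n, d = 2, 16, 4, 32            # made-up sizes for illustration
rescale = 0.5 / jnp.sqrt(d)          # hypothetical custom softmax scale
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (b, t, n, d))
k = jax.random.normal(kk, (b, t, n, d))
v = jax.random.normal(kv, (b, t, n, d))

# Reference: explicit attention with the custom rescale factor.
logits = jnp.einsum("btnd,bsnd->bnts", q, k) * rescale
weights = jax.nn.softmax(logits, axis=-1)
ref = jnp.einsum("bnts,bsnd->btnd", weights, v)

# Fused path: fold the custom scale into q so the default 1/sqrt(d)
# scaling inside jax.nn.dot_product_attention yields the same logits.
fused = jax.nn.dot_product_attention(q * (rescale * jnp.sqrt(d)), k, v)

print(jnp.max(jnp.abs(fused - ref)))  # expect something on the order of 1e-6
```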

hackable_diffusion/lib/architecture/attention_test.py

Lines changed: 16 additions & 0 deletions
@@ -409,5 +409,21 @@ def test_multi_head_attention_invalid_mask_shape_raises_error(
       module.init(self.rng, self.x, c, mask=invalid_mask)
 
 
+
+  def test_optimized_attention_path(self):
+    """Tests the optimized attention path (unmasked)."""
+    module = attention.MultiHeadAttention(num_heads=self.num_heads)
+    variables = module.init(self.rng, self.x, self.c)
+    output = module.apply(variables, self.x, self.c)
+    self.assertEqual(output.shape, self.x.shape)
+
+  def test_masked_attention_path(self):
+    """Tests the manual attention path (masked)."""
+    module = attention.MultiHeadAttention(num_heads=self.num_heads)
+    mask = jnp.ones((self.batch_size, self.seq_len_kv), dtype=jnp.bool_)
+    variables = module.init(self.rng, self.x, self.c, mask=mask)
+    output = module.apply(variables, self.x, self.c, mask=mask)
+    self.assertEqual(output.shape, self.x.shape)
+
 if __name__ == "__main__":
   absltest.main()

hackable_diffusion/lib/architecture/dit_blocks.py

Lines changed: 9 additions & 5 deletions
@@ -158,27 +158,30 @@ def __call__(
       The output tensor.
     """
 
+    # Precompute activation for conditioning
+    cond_act = nn.silu(cond)
+
     # Attention Branch
-    x_attn_modulated = self.conditional_norm(x, c=nn.silu(cond))
+    x_attn_modulated = self.conditional_norm(x, c=cond_act)
     attn_out = self.attn(x_attn_modulated, c=None, mask=mask)
     # Optional dropout
     if self.dropout_rate > 0.0:
       attn_out = nn.Dropout(rate=self.dropout_rate)(
           attn_out, deterministic=not is_training
       )
-    gate_msa = self.gate_msa(nn.silu(cond))
+    gate_msa = self.gate_msa(cond_act)
     # Add a sequence dimension [...,None,:] to broadcast to [*batch,seq,dim].
     x = x + gate_msa[..., None, :] * attn_out
 
     # MLP Branch
-    x_mlp_modulated = self.conditional_norm(x, c=nn.silu(cond))
+    x_mlp_modulated = self.conditional_norm(x, c=cond_act)
     mlp_out = self.mlp(x_mlp_modulated, is_training=is_training)
     # Optional dropout
     if self.dropout_rate > 0.0:
       mlp_out = nn.Dropout(rate=self.dropout_rate)(
           mlp_out, deterministic=not is_training
       )
-    gate_mlp = self.gate_mlp(nn.silu(cond))
+    gate_mlp = self.gate_mlp(cond_act)
     # Add a sequence dimension [...,None,:] to broadcast to [*batch,seq,dim].
     x = x + gate_mlp[..., None, :] * mlp_out
     return x

@@ -267,7 +270,8 @@ def __call__(
     hn = h // hp
     wn = w // wp
 
-    x = self.conditional_norm(x, c=nn.silu(cond))
+    cond_act = nn.silu(cond)
+    x = self.conditional_norm(x, c=cond_act)
     x = nn.Dense(
         features=hp * wp * c,
         name="Dense_Out",

hackable_diffusion/lib/architecture/normalization.py

Lines changed: 34 additions & 7 deletions
@@ -25,6 +25,7 @@
 from hackable_diffusion.lib import hd_typing
 from hackable_diffusion.lib import utils
 from hackable_diffusion.lib.architecture import arch_typing
+import jax
 import jax.numpy as jnp
 import kauldron.ktyping as kt
 

@@ -40,6 +41,17 @@
 NormalizationType = arch_typing.NormalizationType
 
 
+
+################################################################################
+# MARK: Fused Kernels
+################################################################################
+
+def fused_rms_norm(x, scale, epsilon=1e-6):
+  """Fused RMSNorm implementation for XLA efficiency."""
+  ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
+  return x * jax.lax.rsqrt(ms + epsilon) * scale
+
+
 ################################################################################
 # MARK: NormalizationLayer
 ################################################################################

@@ -128,13 +140,17 @@ def __call__(
     ch = x_shape[-1]
 
     if self.normalization_method == NormalizationType.RMS_NORM:
-      x = nn.RMSNorm(
-          epsilon=self.epsilon,
-          dtype=self.dtype,
-          reduction_axes=-1,  # For (B ... ch) results in (B ... ) RMS values.
-          feature_axes=-1,  # Per channel scale.
-          use_scale=self.use_scale,
-      )(x=x, mask=mask)
+      if mask is None and self.use_scale:
+        scale = self.param("scale", nn.initializers.ones, (ch,), self.dtype)
+        x = fused_rms_norm(x, scale, self.epsilon)
+      else:
+        x = nn.RMSNorm(
+            epsilon=self.epsilon,
+            dtype=self.dtype,
+            reduction_axes=-1,  # For (B ... ch) results in (B ... ) RMS values.
+            feature_axes=-1,  # Per channel scale.
+            use_scale=self.use_scale,
+        )(x=x, mask=mask)
     elif self.normalization_method == NormalizationType.GROUP_NORM:
 
       # If using GroupNorm the mask data must be such that the last dimension

@@ -187,6 +203,17 @@ def __call__(
     return x
 
 
+
+################################################################################
+# MARK: Fused Kernels
+################################################################################
+
+def fused_rms_norm(x, scale, epsilon=1e-6):
+  """Fused RMSNorm implementation for XLA efficiency."""
+  ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
+  return x * jax.lax.rsqrt(ms + epsilon) * scale
+
+
 ################################################################################
 # MARK: NormalizationLayerFactory
 ################################################################################
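For intuition, a self-contained sketch (not the repo's test) of the equivalence the new branch relies on: in the unmasked case, the hand-written `fused_rms_norm` should agree with `flax.linen.RMSNorm` up to float tolerance when both use the same per-channel scale. The helper below re-declares the function added in this commit so the snippet runs on its own; the sizes are made up.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

def fused_rms_norm(x, scale, epsilon=1e-6):
  """Single fused expression: x * rsqrt(mean(x^2) + eps) * scale."""
  ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(ms + epsilon) * scale

x = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 32))

ref_layer = nn.RMSNorm(
    epsilon=1e-6, reduction_axes=-1, feature_axes=-1, use_scale=True
)
variables = ref_layer.init(jax.random.PRNGKey(1), x)
ref = ref_layer.apply(variables, x)

# Reuse the layer's (ones-initialized) scale so both paths see the same params.
fused = fused_rms_norm(x, variables["params"]["scale"], epsilon=1e-6)
print(jnp.max(jnp.abs(fused - ref)))  # expect something on the order of 1e-6
```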

hackable_diffusion/lib/architecture/normalization_test.py

Lines changed: 12 additions & 0 deletions
@@ -469,5 +469,17 @@ def test_rmsnorm_mask_equivalence(self):
     )
 
 
+
+  def test_fused_rms_norm_path(self):
+    """Tests the fused RMSNorm path (unmasked)."""
+    module = normalization.NormalizationLayer(
+        normalization_method=normalization.NormalizationType.RMS_NORM,
+        conditional=False
+    )
+    x = jnp.ones((2, 16, 32))
+    variables = module.init(self.rng, x)
+    output = module.apply(variables, x)
+    self.assertEqual(output.shape, x.shape)
+
 if __name__ == "__main__":
   absltest.main()
