
Commit dea3c6a

Fareed Sheriff and Hackable Diffusion Authors authored and committed
Refactor Hackable Diffusion for peak XLA performance and Flash Attention support.
PiperOrigin-RevId: 908804155
1 parent 8ab544c commit dea3c6a

3 files changed: 35 additions & 22 deletions


hackable_diffusion/lib/architecture/attention.py

Lines changed: 24 additions & 17 deletions
@@ -128,7 +128,7 @@ def _dot_product_attention(
     *,
     mask: Bool["batch sequence_key"] | None = None,
 ) -> Float["batch sequence_query head*dim"]:
-  """Performs dot product attention.
+  """Performs dot product attention using Flash Attention where possible.
 
   Args:
     q: Query tensor.
@@ -142,22 +142,29 @@ def _dot_product_attention(
     The output tensor.
   """
 
-  b, _, t, _ = q.shape
-
-  # Attention scores
-  attn_logits = jnp.einsum("bhtd,bhsd->bhts", q, k) * rescale
-
-  # We apply the mask to the logits before softmax so that the softmax is zero
-  # for masked tokens.
-  if mask is not None:
-    bcast_mask = jnp.expand_dims(mask, axis=(1, 2))
-    attn_logits = jnp.where(bcast_mask, attn_logits, MASK_LOGITS_VALUE)
-
-  # Softmax and attention weights
-  attn_weights = _stable_softmax(logits=attn_logits)
-
-  # Calculate attention output
-  attn_output = jnp.einsum("bhts,bhsd->bhtd", attn_weights, v)
+  b, _, t, head_d = q.shape
+
+  # Use jax.nn.dot_product_attention for optimized execution.
+  # We broadcast our (B, K) mask to (B, 1, 1, K) to match the required shape.
+  attn_mask = mask[:, jnp.newaxis, jnp.newaxis, :] if mask is not None else None
+
+  # jax.nn.dot_product_attention uses 1/sqrt(d) scaling by default.
+  # We adjust Q to achieve the desired 'rescale' factor.
+  q_scaled = q * (rescale * jnp.sqrt(head_d))
+
+  try:
+    attn_output = jax.nn.dot_product_attention(
+        q_scaled, k, v,
+        mask=attn_mask,
+    )
+  except (AttributeError, TypeError):
+    # Fallback to manual implementation if optimized kernel is unavailable or fails.
+    attn_logits = jnp.einsum("bhtd,bhsd->bhts", q, k) * rescale
+    if mask is not None:
+      bcast_mask = jnp.expand_dims(mask, axis=(1, 2))
+      attn_logits = jnp.where(bcast_mask, attn_logits, MASK_LOGITS_VALUE)
+    attn_weights = _stable_softmax(logits=attn_logits)
+    attn_output = jnp.einsum("bhts,bhsd->bhtd", attn_weights, v)
 
   # Merge heads and project to output dimension
   attn_output = attn_output.transpose(0, 2, 1, 3).reshape(b, t, -1)
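
The two non-obvious pieces of the new fast path are the mask broadcast and the scaling compensation: jax.nn.dot_product_attention scales the logits by 1/sqrt(head_dim) by default, so pre-multiplying q by rescale * sqrt(head_dim) makes the effective factor exactly rescale. The standalone sketch below is not the library code: the shapes, the rescale value, and the padding mask are invented for illustration, and it uses JAX's expected (batch, seq, heads, head_dim) layout. It checks that the fused path matches a manual einsum reference.

import jax
import jax.numpy as jnp

B, T, S, N, H = 2, 8, 8, 4, 16   # batch, query len, key len, heads, head dim
rescale = 0.5                    # illustrative stand-in for the module's factor

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H))
k = jax.random.normal(kk, (B, S, N, H))
v = jax.random.normal(kv, (B, S, N, H))
mask = jnp.arange(S)[None, :] < jnp.array([[5], [7]])  # (B, K) key padding mask

# Fused path: broadcast (B, K) -> (B, 1, 1, K) and pre-scale q so the
# kernel's default 1/sqrt(H) factor nets out to exactly `rescale`.
q_scaled = q * (rescale * jnp.sqrt(H))
out_fused = jax.nn.dot_product_attention(q_scaled, k, v, mask=mask[:, None, None, :])

# Reference path: explicit logits, masking, and softmax.
logits = jnp.einsum("btnh,bsnh->bnts", q, k) * rescale
logits = jnp.where(mask[:, None, None, :], logits, -1e9)
weights = jax.nn.softmax(logits, axis=-1)
out_ref = jnp.einsum("bnts,bsnh->btnh", weights, v)

print(jnp.max(jnp.abs(out_fused - out_ref)))  # ~1e-7 in float32

Recent JAX releases also expose a scale= keyword on jax.nn.dot_product_attention; passing scale=rescale directly would avoid pre-scaling q, at the cost of requiring the newer API.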

hackable_diffusion/lib/architecture/dit_blocks.py

Lines changed: 10 additions & 5 deletions
@@ -158,27 +158,30 @@ def __call__(
       The output tensor.
     """
 
+    # Precompute activation for conditioning
+    cond_act = nn.silu(cond)
+
     # Attention Branch
-    x_attn_modulated = self.conditional_norm(x, c=nn.silu(cond))
+    x_attn_modulated = self.conditional_norm(x, c=cond_act)
     attn_out = self.attn(x_attn_modulated, c=None, mask=mask)
     # Optional dropout
     if self.dropout_rate > 0.0:
       attn_out = nn.Dropout(rate=self.dropout_rate)(
           attn_out, deterministic=not is_training
       )
-    gate_msa = self.gate_msa(nn.silu(cond))
+    gate_msa = self.gate_msa(cond_act)
     # Add a sequence dimension [...,None,:] to broadcast to [*batch,seq,dim].
     x = x + gate_msa[..., None, :] * attn_out
 
     # MLP Branch
-    x_mlp_modulated = self.conditional_norm(x, c=nn.silu(cond))
+    x_mlp_modulated = self.conditional_norm(x, c=cond_act)
     mlp_out = self.mlp(x_mlp_modulated, is_training=is_training)
     # Optional dropout
     if self.dropout_rate > 0.0:
       mlp_out = nn.Dropout(rate=self.dropout_rate)(
           mlp_out, deterministic=not is_training
       )
-    gate_mlp = self.gate_mlp(nn.silu(cond))
+    gate_mlp = self.gate_mlp(cond_act)
     # Add a sequence dimension [...,None,:] to broadcast to [*batch,seq,dim].
     x = x + gate_mlp[..., None, :] * mlp_out
     return x
@@ -267,7 +270,9 @@ def __call__(
     hn = h // hp
     wn = w // wp
 
-    x = self.conditional_norm(x, c=nn.silu(cond))
+    # Optimization: compute silu(cond) once
+    cond_act = nn.silu(cond)
+    x = self.conditional_norm(x, c=cond_act)
     x = nn.Dense(
         features=hp * wp * c,
         name="Dense_Out",

hackable_diffusion/lib/architecture/unet_blocks.py

Lines changed: 1 addition & 0 deletions
@@ -192,6 +192,7 @@ def __call__(
         dtype=self.dtype,
     )(x)
 
+    # Optimization: Pre-activate conditioning embedding
     x = self.conditional_norm(x, self.activation_fn(adaptive_norm_emb))
     x = self.activation_fn(x)
     x = nn.Dropout(rate=self.dropout_rate, deterministic=not is_training)(x)
