43 changes: 43 additions & 0 deletions hackable_diffusion/benchmarks/run_benchmarks.py
@@ -0,0 +1,43 @@
# Copyright 2026 Hackable Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmark suite for Hackable Diffusion optimizations."""

from absl import app
from hackable_diffusion.lib.architecture import attention
import jax
import jax.numpy as jnp
import time

def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  print("Benchmarking Attention...")
  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (8, 1024, 512))

  attn = attention.MultiHeadAttention(num_heads=8)
  variables = attn.init(key, x, None)
  apply_fn = jax.jit(lambda x: attn.apply(variables, x, None))
  # Warmup: the first call triggers compilation.
  apply_fn(x).block_until_ready()

  start = time.time()
  for _ in range(100):
    apply_fn(x).block_until_ready()
  end = time.time()
  # 100 iterations, seconds -> milliseconds.
  print(f"Average latency: {(end - start) / 100 * 1000:.2f} ms")


if __name__ == "__main__":
  app.run(main)
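As a side note on methodology (not part of the diff): the warmup-then-time loop above generalizes into a small helper. A minimal sketch, assuming any jit-compatible callable; the helper name and defaults are purely illustrative.

```python
import time
import jax


def benchmark(fn, *args, iters=100, warmup=3):
  """Hypothetical helper: average latency of a jitted callable, in ms."""
  jitted = jax.jit(fn)
  for _ in range(warmup):
    jax.block_until_ready(jitted(*args))  # first call also compiles
  start = time.perf_counter()
  for _ in range(iters):
    jax.block_until_ready(jitted(*args))
  return (time.perf_counter() - start) / iters * 1000.0
```

Called as `benchmark(lambda x: attn.apply(variables, x, None), x)`, it would reproduce the measurement above.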
42 changes: 42 additions & 0 deletions hackable_diffusion/benchmarks/verify_fidelity.py
@@ -0,0 +1,42 @@
# Copyright 2026 Hackable Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Verification script for numerical fidelity."""

from absl import app
from hackable_diffusion.lib.architecture import attention
import jax
import jax.numpy as jnp

def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (2, 64, 128))
  attn = attention.MultiHeadAttention(num_heads=4)
  variables = attn.init(key, x, None)

  # Passing an all-ones mask forces the original (manual) einsum path, while
  # the unmasked call takes the optimized path.
  mask = jnp.ones((2, 64), dtype=jnp.bool_)
  out_opt = attn.apply(variables, x, None)
  out_manual = attn.apply(variables, x, None, mask=mask)

  diff = jnp.abs(out_opt - out_manual).max()
  print(f"Max difference: {diff}")
  if diff >= 1e-5:
    raise ValueError(f"Fidelity check failed: diff={diff}")


if __name__ == "__main__":
  app.run(main)
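One small observation on the check above (an aside, not a requested change): a fixed absolute threshold of 1e-5 can be brittle when activations are not O(1). A hedged sketch of a combined relative/absolute comparison; the helper name and tolerances are illustrative.

```python
import jax.numpy as jnp


def assert_close(a, b, rtol=1e-4, atol=1e-5):
  """Hypothetical helper: relative + absolute tolerance comparison."""
  if not jnp.allclose(a, b, rtol=rtol, atol=atol):
    raise ValueError(f"Fidelity check failed: max abs diff {jnp.abs(a - b).max()}")
```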
21 changes: 20 additions & 1 deletion hackable_diffusion/lib/architecture/attention.py
@@ -142,7 +142,26 @@ def _dot_product_attention(
    The output tensor.
  """

  b, _, t, _ = q.shape
  b, h, t, d = q.shape
  s = k.shape[2]

  # Use jax.nn.dot_product_attention for hardware acceleration when possible.
  # Some of its kernels require matching (or compatible) sequence lengths, so
  # only take this path in the common case where q_seq == kv_seq and no mask
  # is present.
  if mask is None and t == s and hasattr(jax.nn, "dot_product_attention"):
    try:
      # jax.nn.dot_product_attention expects (batch, seq, heads, dim), the
      # standard layout for Flash Attention kernels.
      q_opt = q.transpose(0, 2, 1, 3)
      k_opt = k.transpose(0, 2, 1, 3)
      v_opt = v.transpose(0, 2, 1, 3)

      # jax.nn.dot_product_attention defaults to a 1/sqrt(d) scale; pass our
      # rescale explicitly so the scaling matches the manual path below.
      attn_output = jax.nn.dot_product_attention(q_opt, k_opt, v_opt, scale=rescale)
      # Output is (batch, seq, heads, dim); flatten heads into the feature dim.
      return attn_output.reshape(b, t, -1)
    except Exception:  # pylint: disable=broad-except
      pass

  # Attention scores
  attn_logits = jnp.einsum("bhtd,bhsd->bhts", q, k) * rescale
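For reviewers who want to sanity-check the layout and scaling assumptions in the new fast path, here is a small standalone sketch (not part of the PR). It assumes the fallback path applies a softmax over the `bhts` logits followed by a values contraction, which is the usual formulation; all shapes are illustrative, and the tolerance may need loosening on accelerator backends.

```python
import jax
import jax.numpy as jnp

b, h, t, d = 2, 4, 64, 32
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (b, h, t, d))
k = jax.random.normal(kk, (b, h, t, d))
v = jax.random.normal(kv, (b, h, t, d))
rescale = 1.0 / (d ** 0.5)

# Manual path: same einsum as the fallback, then softmax and values einsum.
logits = jnp.einsum("bhtd,bhsd->bhts", q, k) * rescale
weights = jax.nn.softmax(logits, axis=-1)
manual = jnp.einsum("bhts,bhsd->bhtd", weights, v)

# Fused path: jax.nn.dot_product_attention takes (batch, seq, heads, dim)
# and defaults to a 1/sqrt(d) scale unless one is passed explicitly.
fused = jax.nn.dot_product_attention(
    q.transpose(0, 2, 1, 3),
    k.transpose(0, 2, 1, 3),
    v.transpose(0, 2, 1, 3),
    scale=rescale,
).transpose(0, 2, 1, 3)  # back to (batch, heads, seq, dim)

assert jnp.allclose(manual, fused, atol=1e-4)
```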
8 changes: 8 additions & 0 deletions hackable_diffusion/lib/architecture/attention_test.py
@@ -409,5 +409,13 @@ def test_multi_head_attention_invalid_mask_shape_raises_error(
      module.init(self.rng, self.x, c, mask=invalid_mask)

  def test_optimized_attention_path(self):
    """Tests the optimized attention path (unmasked)."""
    module = attention.MultiHeadAttention(num_heads=self.num_heads)
    variables = module.init(self.rng, self.x, self.c)
    output = module.apply(variables, self.x, self.c)
    self.assertEqual(output.shape, self.x.shape)


if __name__ == "__main__":
  absltest.main()
12 changes: 7 additions & 5 deletions hackable_diffusion/lib/architecture/dit_blocks.py
@@ -159,26 +159,27 @@ def __call__(
"""

# Attention Branch
x_attn_modulated = self.conditional_norm(x, c=nn.silu(cond))
cond_act = nn.silu(cond)
x_attn_modulated = self.conditional_norm(x, c=cond_act)
attn_out = self.attn(x_attn_modulated, c=None, mask=mask)
# Optional dropout
if self.dropout_rate > 0.0:
attn_out = nn.Dropout(rate=self.dropout_rate)(
attn_out, deterministic=not is_training
)
gate_msa = self.gate_msa(nn.silu(cond))
gate_msa = self.gate_msa(cond_act)
# Add a sequence dimension [...,None,:] to broadcast to [*batch,seq,dim].
x = x + gate_msa[..., None, :] * attn_out

# MLP Branch
x_mlp_modulated = self.conditional_norm(x, c=nn.silu(cond))
x_mlp_modulated = self.conditional_norm(x, c=cond_act)
mlp_out = self.mlp(x_mlp_modulated, is_training=is_training)
# Optional dropout
if self.dropout_rate > 0.0:
mlp_out = nn.Dropout(rate=self.dropout_rate)(
mlp_out, deterministic=not is_training
)
gate_mlp = self.gate_mlp(nn.silu(cond))
gate_mlp = self.gate_mlp(cond_act)
# Add a sequence dimension [...,None,:] to broadcast to [*batch,seq,dim].
x = x + gate_mlp[..., None, :] * mlp_out
return x
@@ -267,7 +268,8 @@ def __call__(
    hn = h // hp
    wn = w // wp

    x = self.conditional_norm(x, c=nn.silu(cond))
    cond_act = nn.silu(cond)
    x = self.conditional_norm(x, c=cond_act)
    x = nn.Dense(
        features=hp * wp * c,
        name="Dense_Out",
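A quick shape note on the `gate[..., None, :]` broadcasting that the comments above reference; the numbers below are purely illustrative.

```python
import jax.numpy as jnp

gate = jnp.ones((2, 512))          # e.g. output of self.gate_msa(cond_act): (*batch, dim)
branch = jnp.ones((2, 64, 512))    # attention / MLP branch output: (*batch, seq, dim)
out = gate[..., None, :] * branch  # (2, 1, 512) broadcasts over seq -> (2, 64, 512)
assert out.shape == (2, 64, 512)
```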
33 changes: 25 additions & 8 deletions hackable_diffusion/lib/architecture/normalization.py
@@ -25,6 +25,7 @@
from hackable_diffusion.lib import hd_typing
from hackable_diffusion.lib import utils
from hackable_diffusion.lib.architecture import arch_typing
import jax
import jax.numpy as jnp
import kauldron.ktyping as kt

@@ -40,6 +41,24 @@
NormalizationType = arch_typing.NormalizationType


################################################################################
# MARK: Fused Kernels
################################################################################


def fused_rms_norm(x, scale, epsilon=1e-6, mask=None):
  """Fused RMSNorm implementation for XLA efficiency.

  Note: the mask is accepted for API compatibility but does not change the
  mean-square computation. Flax's RMSNorm would compute
  mean(square(x) * mask) / mean(mask) over the reduction axes, but with
  reduction_axes=-1 and a (B, ..., 1) mask that ratio is either the plain
  mean (mask == 1) or degenerate (mask == 0), so the standard mean is used
  here for now.
  """
  del mask  # Unused; see docstring.
  ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(ms + epsilon) * scale


################################################################################
# MARK: NormalizationLayer
################################################################################
@@ -76,7 +95,7 @@ class NormalizationLayer(nn.Module):
    num_groups: The number of groups to use for group normalization. If None,
      group normalization cannot be used and an error will be raised.
    epsilon: Epsilon value for numerical stability in normalization.
    dtype: The data type of the computation.
    use_bias: Whether to use bias in the normalization layer.
    use_scale: Whether to use scale in the normalization layer.
  """
@@ -128,13 +147,11 @@ def __call__(
    ch = x_shape[-1]

    if self.normalization_method == NormalizationType.RMS_NORM:
      x = nn.RMSNorm(
          epsilon=self.epsilon,
          dtype=self.dtype,
          reduction_axes=-1,  # For (B ... ch) results in (B ... ) RMS values.
          feature_axes=-1,  # Per channel scale.
          use_scale=self.use_scale,
      )(x=x, mask=mask)
      if self.use_scale:
        scale = self.param("scale", nn.initializers.ones, (ch,), self.dtype)
      else:
        scale = 1.0
      x = fused_rms_norm(x, scale, self.epsilon, mask=mask)
    elif self.normalization_method == NormalizationType.GROUP_NORM:

      # If using GroupNorm the mask data must be such that the last dimension
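To double-check that `fused_rms_norm` reproduces the Flax `nn.RMSNorm` it replaces (in the unmasked case), something like the following could be run locally. It is only a sketch: it assumes the PR's `normalization` module is importable from `hackable_diffusion.lib.architecture`, and it reuses the per-channel scale parameter initialized by Flax.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

from hackable_diffusion.lib.architecture import normalization

x = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 128))

# Reference: Flax RMSNorm with a learned per-channel scale.
ref = nn.RMSNorm(epsilon=1e-6, reduction_axes=-1, feature_axes=-1, use_scale=True)
variables = ref.init(jax.random.PRNGKey(1), x)
ref_out = ref.apply(variables, x)

# Fused path, reusing the same per-channel scale parameter.
scale = variables["params"]["scale"]
fused_out = normalization.fused_rms_norm(x, scale, epsilon=1e-6)

assert jnp.allclose(ref_out, fused_out, atol=1e-5)
```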