|
25 | 25 | from hackable_diffusion.lib import hd_typing |
26 | 26 | from hackable_diffusion.lib import utils |
27 | 27 | from hackable_diffusion.lib.architecture import arch_typing |
| 28 | +import jax |
28 | 29 | import jax.numpy as jnp |
29 | 30 | import kauldron.ktyping as kt |
30 | 31 |
|
|
40 | 41 | NormalizationType = arch_typing.NormalizationType |
41 | 42 |
|
42 | 43 |
|
| 44 | +################################################################################ |
| 45 | +# MARK: Fused Kernels |
| 46 | +################################################################################ |
| 47 | + |
| 48 | + |
| 49 | +def fused_rms_norm(x, scale, epsilon=1e-6, mask=None):
| 50 | +  """Fused RMSNorm implementation for XLA efficiency."""
| 51 | +  # `mask` is accepted for API parity with flax.linen.RMSNorm. Flax computes
| 52 | +  # the masked mean-square as mean(square(x) * mask, axis) / mean(mask, axis);
| 53 | +  # with reduction axis -1 and a mask of shape (B, ..., 1) that broadcasts
| 54 | +  # over the feature axis, mean(mask) is either 1.0 or 0.0, so for unmasked
| 55 | +  # rows this is just the plain mean over features. We therefore use the
| 56 | +  # standard mean here.
| 57 | +  del mask
| 58 | +  ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
| 59 | +  return x * jax.lax.rsqrt(ms + epsilon) * scale
| 60 | + |
| 61 | + |
43 | 62 | ################################################################################ |
44 | 63 | # MARK: NormalizationLayer |
45 | 64 | ################################################################################ |
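For context on the mask handling above, a minimal sketch (assuming a toy (B, ch) input and an all-ones (B, 1) mask; not part of this commit) of why the Flax-style masked mean-square matches the plain mean-square used by fused_rms_norm:

import jax.numpy as jnp

x = jnp.arange(8.0).reshape(2, 4)   # toy (B, ch) input
mask = jnp.ones((2, 1))             # broadcasts over the feature axis

# Flax-style masked mean-square vs. the plain mean-square used above.
masked_ms = jnp.mean(jnp.square(x) * mask, axis=-1, keepdims=True) / jnp.mean(
    mask, axis=-1, keepdims=True)
plain_ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
assert jnp.allclose(masked_ms, plain_ms)  # equal whenever the row is unmasked

For a fully masked row the Flax expression degenerates to 0/0, which is part of why the fused kernel simply uses the standard mean.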
@@ -76,7 +95,7 @@ class NormalizationLayer(nn.Module): |
76 | 95 | num_groups: The number of groups to use for group normalization. If None, |
77 | 96 | group normalization cannot be used and an error will be raised. |
78 | 97 | epsilon: Epsilon value for numerical stability in normalization. |
79 | | - dtype: The data type of the computation. |
| 98 | + dtype: The data type of the computation. Defaults to jnp.float32.
80 | 99 | use_bias: Whether to use bias in the normalization layer. |
81 | 100 | use_scale: Whether to use scale in the normalization layer. |
82 | 101 | """ |
@@ -128,13 +147,11 @@ def __call__( |
128 | 147 | ch = x_shape[-1] |
129 | 148 |
|
130 | 149 | if self.normalization_method == NormalizationType.RMS_NORM: |
131 | | - x = nn.RMSNorm( |
132 | | - epsilon=self.epsilon, |
133 | | - dtype=self.dtype, |
134 | | - reduction_axes=-1, # For (B ... ch) results in (B ... ) RMS values. |
135 | | - feature_axes=-1, # Per channel scale. |
136 | | - use_scale=self.use_scale, |
137 | | - )(x=x, mask=mask) |
| 150 | + if self.use_scale: |
| 151 | + scale = self.param("scale", nn.initializers.ones, (ch,), self.dtype) |
| 152 | + else: |
| 153 | + scale = 1.0 |
| 154 | + x = fused_rms_norm(x, scale, self.epsilon, mask=mask) |
138 | 155 | elif self.normalization_method == NormalizationType.GROUP_NORM: |
139 | 156 |
|
140 | 157 | # If using GroupNorm the mask data must be such that the last dimension |
|
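For reference, a minimal sketch (assuming default flax.linen.RMSNorm settings and an all-ones scale, i.e. a freshly initialized layer) of how the fused path above can be checked against the nn.RMSNorm call it replaces; fused_rms_norm refers to the function added in this commit:

import flax.linen as nn
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8))

# Reference: the Flax layer previously used by NormalizationLayer.
ref = nn.RMSNorm(epsilon=1e-6, reduction_axes=-1, feature_axes=-1, use_scale=True)
variables = ref.init(jax.random.PRNGKey(1), x)
y_ref = ref.apply(variables, x)

# Fused path with an all-ones scale, matching the default scale initializer.
y_fused = fused_rms_norm(x, scale=jnp.ones((8,)), epsilon=1e-6)
assert jnp.allclose(y_ref, y_fused, atol=1e-6)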