25 | 25 | from hackable_diffusion.lib import hd_typing |
26 | 26 | from hackable_diffusion.lib import utils |
27 | 27 | from hackable_diffusion.lib.architecture import arch_typing |
| 28 | +import jax |
28 | 29 | import jax.numpy as jnp |
29 | 30 | import kauldron.ktyping as kt |
30 | 31 |
|
|
40 | 41 | NormalizationType = arch_typing.NormalizationType |
41 | 42 |
|
42 | 43 |
|
| 44 | + |
| 45 | +################################################################################ |
| 46 | +# MARK: Fused Kernels |
| 47 | +################################################################################ |
| 48 | + |
| 49 | +def fused_rms_norm(x: jax.Array, scale: jax.Array, epsilon: float = 1e-6):
| 50 | +  """RMSNorm, x * rsqrt(mean(x**2) + eps) * scale, as ops XLA fuses well."""
| 51 | +  ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
| 52 | +  return x * jax.lax.rsqrt(ms + epsilon) * scale
| 53 | + |
| 54 | + |
43 | 55 | ################################################################################ |
44 | 56 | # MARK: NormalizationLayer |
45 | 57 | ################################################################################ |
@@ -128,13 +140,17 @@ def __call__( |
128 | 140 |     ch = x_shape[-1]
129 | 141 |

130 | 142 |     if self.normalization_method == NormalizationType.RMS_NORM:
131 | | -      x = nn.RMSNorm(
132 | | -          epsilon=self.epsilon,
133 | | -          dtype=self.dtype,
134 | | -          reduction_axes=-1,  # For (B ... ch) results in (B ... ) RMS values.
135 | | -          feature_axes=-1,  # Per channel scale.
136 | | -          use_scale=self.use_scale,
137 | | -      )(x=x, mask=mask)
| 143 | +      if mask is None and self.use_scale:  # Fused path handles the unmasked case.
| 144 | +        scale = self.param("scale", nn.initializers.ones, (ch,), self.dtype)
| 145 | +        x = fused_rms_norm(x, scale, self.epsilon)
| 146 | +      else:
| 147 | +        x = nn.RMSNorm(
| 148 | +            epsilon=self.epsilon,
| 149 | +            dtype=self.dtype,
| 150 | +            reduction_axes=-1,  # For (B ... ch) results in (B ... ) RMS values.
| 151 | +            feature_axes=-1,  # Per channel scale.
| 152 | +            use_scale=self.use_scale,
| 153 | +        )(x=x, mask=mask)
138 | 154 |     elif self.normalization_method == NormalizationType.GROUP_NORM:
139 | 155 |

140 | 156 |       # If using GroupNorm the mask data must be such that the last dimension
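Not part of the diff above: a minimal standalone sketch for sanity-checking that the unmasked fast path matches flax.linen.RMSNorm. It assumes a flax version that provides nn.RMSNorm, redefines fused_rms_norm locally so the snippet is self-contained, and uses illustrative shapes and tolerance.

    import flax.linen as nn
    import jax
    import jax.numpy as jnp


    def fused_rms_norm(x, scale, epsilon=1e-6):
      # Same expression as the kernel added in this commit.
      ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
      return x * jax.lax.rsqrt(ms + epsilon) * scale


    x = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 8))
    # Reference: flax's RMSNorm with its default last-axis reduction and
    # per-channel scale, i.e. the configuration the else-branch uses.
    ref = nn.RMSNorm(epsilon=1e-6, use_scale=True)
    variables = ref.init(jax.random.PRNGKey(1), x)
    expected = ref.apply(variables, x)
    got = fused_rms_norm(x, variables["params"]["scale"], epsilon=1e-6)
    assert jnp.allclose(expected, got, atol=1e-5)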