Skip to content

Commit 165d642

Browse files
agalashov and Hackable Diffusion Authors
authored and committed
Add AdditiveSequenceEmbedding
PiperOrigin-RevId: 872366209
1 parent 00cd04e commit 165d642

6 files changed

Lines changed: 62 additions & 4 deletions

File tree

hackable_diffusion/lib/architecture/mlp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __call__(
7878
self,
7979
x: DataArray,
8080
conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']],
81+
*,
8182
is_training: bool,
8283
) -> DataArray:
8384
x_emb = jnp.reshape(x, shape=(x.shape[0], -1))
@@ -92,7 +93,7 @@ def __call__(
9293
dropout_rate=self.dropout_rate,
9394
dtype=self.dtype,
9495
name='PreprocessMLP',
95-
)(x_emb, is_training)
96+
)(x_emb, is_training=is_training)
9697

9798
# The conditioning was already processed by the `conditioning_encoder`, so
9899
# here we just need to concatenate it with the `x`.
@@ -125,7 +126,7 @@ def __call__(
125126
dtype=self.dtype,
126127
zero_init_output=self.zero_init_output,
127128
name='PostprocessMLP',
128-
)(emb, is_training)
129+
)(emb, is_training=is_training)
129130

130131
output = jnp.reshape(output, shape=x.shape)
131132
output = utils.optional_bf16_to_fp32(output)

hackable_diffusion/lib/architecture/mlp_blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class MLP(nn.Module):
4646
@nn.compact
4747
@typechecked
4848
def __call__(
49-
self, x: Float['batch num_inputs'], is_training: bool
49+
self, x: Float['batch num_inputs'], *, is_training: bool
5050
) -> Float['batch num_features']:
5151
"""Applies MLP blocks to the input tensor.
5252

hackable_diffusion/lib/architecture/sequence_embedders.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,26 @@ def __call__(
186186

187187
out = jnp.asarray(jnp.concatenate(result, axis=-1), x.dtype)
188188
return out
189+
190+
class AdditiveSequenceEmbedding(nn.Module):
  """Learnable additive positional embedding for sequences.

  Adds a trainable offset tensor to the input. The offset has a leading
  axis of size 1 so it broadcasts over the batch dimension, i.e. every
  example in the batch receives the same learned per-position values.

  Attributes:
    num_features: Declared feature count. It is only validated in
      `setup`; the embedding shape itself is derived from the input
      (NOTE(review): presumably kept for interface parity with the other
      sequence embedders — confirm).
  """

  num_features: int

  def setup(self):
    # Fail fast on misconfiguration, mirroring the sibling embedders.
    if self.num_features <= 0:
      raise ValueError("Number of features must be positive.")

  @nn.compact
  @typechecked
  def __call__(
      self, x: Num["batch *#data_shape"]
  ) -> Float["batch *#data_shape"]:
    # One learned value per position, shared across the batch. The shape
    # drops the batch axis and replaces it with 1 for broadcasting.
    offset_shape = (1, *x.shape[1:])
    learned_offset = self.param(
        "PositionalEmbeddingTensor",
        nn.initializers.normal(stddev=0.02),
        offset_shape,
        x.dtype,
    )
    return x + learned_offset

hackable_diffusion/lib/architecture/sequence_embedders_test.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@
3838
def _get_invalid_num_features_params():
3939
"""Generates parameters for testing invalid num_features."""
4040
params = []
41-
modes = ["sinusoidal_embedding", "random_fourier_embedding"]
41+
modes = [
42+
"sinusoidal_embedding",
43+
"random_fourier_embedding",
44+
"additive_embedding",
45+
]
4246
feature_values = [
4347
("default", INVALID_INT),
4448
("zero", 0),
@@ -102,6 +106,10 @@ def test_sequence_embedding_raises_error_on_invalid_num_features(
102106
module = sequence_embedders.RandomFourierSequenceEmbedding(
103107
num_features=num_features
104108
)
109+
elif mode == "additive_embedding":
110+
module = sequence_embedders.AdditiveSequenceEmbedding(
111+
num_features=num_features
112+
)
105113
else:
106114
self.fail(f"Unknown mode: {mode}")
107115
inputs = jnp.arange(self.batch_size)
@@ -244,6 +252,30 @@ def test_rope_embedding_has_no_params(self):
244252
variables = module.init(self.rng, x_rope)
245253
self.assertEmpty(variables)
246254

255+
# MARK: AdditiveSequenceEmbedding tests
256+
257+
def test_additive_embedding_output_shape(self):
  """Checks that AdditiveSequenceEmbedding preserves the input shape."""
  embedder = sequence_embedders.AdditiveSequenceEmbedding(
      num_features=self.dim
  )
  params = embedder.init({"params": self.rng}, self.x)
  result = embedder.apply(params, self.x)
  # An additive embedding must not change the shape of its input.
  self.assertEqual(result.shape, self.x.shape)
264+
def test_additive_embedding_params_are_updated(self):
  """Checks that gradients flow into the positional embedding."""
  embedder = sequence_embedders.AdditiveSequenceEmbedding(
      num_features=self.dim
  )
  initial_params = embedder.init({"params": self.rng}, self.x)["params"]

  def loss_fn(params):
    return jnp.sum(embedder.apply({"params": params}, self.x))

  grads = jax.grad(loss_fn)(initial_params)

  # A non-zero gradient means the embedding actually participates in the
  # output and will be updated by the optimizer.
  self.assertFalse(jnp.allclose(grads["PositionalEmbeddingTensor"], 0.0))
278+
247279

248280
if __name__ == "__main__":
249281
absltest.main()

hackable_diffusion/lib/architecture/unet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def __call__(
164164
self,
165165
x: Float["batch height width channels"],
166166
conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]],
167+
*,
167168
is_training: bool,
168169
) -> Float["batch height width output_channels"]:
169170

hackable_diffusion/lib/architecture/unet_blocks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def __call__(
266266
self,
267267
x: Float["batch height width channels"],
268268
cross_attention_emb: Float["batch seq cond_dim2"] | None,
269+
*,
269270
is_training: bool,
270271
) -> Float["batch height width channels"]:
271272
skip = x

0 commit comments

Comments
 (0)