Skip to content

Commit ff6d3dd

Browse files
ccrepy and Hackable Diffusion Authors
authored and committed
Relax ConditioningEmbedding by removing Conditioning Mechanism
PiperOrigin-RevId: 911898415
1 parent ae44b17 commit ff6d3dd

27 files changed

Lines changed: 137 additions & 142 deletions

docs/architecture.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@ to standardize the architecture's configuration.
3232

3333
### `ConditioningMechanism`
3434

35-
This enum specifies how a conditioning signal is injected into the backbone.
35+
This string specifies how a conditioning signal is injected into the backbone.
3636

37-
* `ADAPTIVE_NORM`: The conditioning embedding is used to modulate the scale
37+
* `adaptive_norm`: The conditioning embedding is used to modulate the scale
3838
and shift in an adaptive normalization layer (e.g., AdaLN).
39-
* `CROSS_ATTENTION`: The conditioning embedding is used as the key and value
39+
* `cross_attention`: The conditioning embedding is used as the key and value
4040
in a cross-attention layer, with the model's intermediate representation as
4141
the query.
42-
* `CONCATENATE`: The conditioning is concatenated to the input of a layer or
42+
* `concatenate`: The conditioning is concatenated to the input of a layer or
4343
module.
44-
* `SUM`: The conditioning is added to the input of a layer or module.
44+
* `sum`: The conditioning is added to the input of a layer or module.
4545

4646
### `EmbeddingMergeMethod`
4747

@@ -105,8 +105,8 @@ adaptive_norm_emb = jnp.ones((1, 128))
105105
cross_attention_emb = jnp.ones((1, 10, 256)) # 10 tokens, 256 dim
106106

107107
conditioning_embeddings = {
108-
ConditioningMechanism.ADAPTIVE_NORM: adaptive_norm_emb,
109-
ConditioningMechanism.CROSS_ATTENTION: cross_attention_emb,
108+
'adaptive_norm': adaptive_norm_emb,
109+
'cross_attention': cross_attention_emb,
110110
}
111111

112112
unet = Unet(
@@ -317,9 +317,9 @@ conditioning_encoder = ConditioningEncoder(
317317
},
318318
embedding_merging_method=EmbeddingMergeMethod.SUM,
319319
conditioning_rules={
320-
'time': ConditioningMechanism.ADAPTIVE_NORM,
321-
'label_adanorm': ConditioningMechanism.ADAPTIVE_NORM,
322-
'label_xattn': ConditioningMechanism.CROSS_ATTENTION,
320+
'time': 'adaptive_norm',
321+
'label_adanorm': 'adaptive_norm',
322+
'label_xattn': 'cross_attention',
323323
},
324324
conditioning_dropout_rate=0.1,
325325
)
@@ -339,8 +339,8 @@ output_embeddings = conditioning_encoder.apply(
339339
)
340340

341341
# 5. Inspect the output
342-
adanorm_emb = output_embeddings[ConditioningMechanism.ADAPTIVE_NORM]
343-
xattn_emb = output_embeddings[ConditioningMechanism.CROSS_ATTENTION]
342+
adanorm_emb = output_embeddings['adaptive_norm']
343+
xattn_emb = output_embeddings['cross_attention']
344344

345345
# The adanorm embedding is the sum of time and label_adanorm embeddings
346346
print(f"Adaptive Norm embedding shape: {adanorm_emb.shape}")

hackable_diffusion/kdiff/configs/imagenet64_unet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def get_config():
6565
},
6666
embedding_merging_method=hd.architecture.EmbeddingMergeMethod.SUM,
6767
conditioning_rules={
68-
"label": hd.architecture.ConditioningMechanism.ADAPTIVE_NORM,
69-
"time": hd.architecture.ConditioningMechanism.ADAPTIVE_NORM,
68+
"label": 'adaptive_norm',
69+
"time": 'adaptive_norm',
7070
},
7171
conditioning_dropout_rate=0.1,
7272
)

hackable_diffusion/lib/architecture/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
# pylint: disable=g-importing-member
1818
from hackable_diffusion.lib.architecture.arch_typing import ConditionalBackbone
19-
from hackable_diffusion.lib.architecture.arch_typing import ConditioningMechanism
2019
from hackable_diffusion.lib.architecture.arch_typing import DownsampleType
2120
from hackable_diffusion.lib.architecture.arch_typing import EmbeddingMergeMethod
2221
from hackable_diffusion.lib.architecture.arch_typing import NormalizationType

hackable_diffusion/lib/architecture/arch_typing.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""
2020

2121
import enum
22-
from typing import Callable, Protocol
22+
from typing import Any, Callable, Protocol
2323
from hackable_diffusion.lib import hd_typing
2424
import jax
2525

@@ -50,17 +50,6 @@ class EmbeddingMergeMethod(enum.StrEnum):
5050
CONCAT = "concat"
5151

5252

53-
class ConditioningMechanism(enum.StrEnum):
54-
"""Types of conditioning mechanisms."""
55-
56-
ADAPTIVE_NORM = "adaptive_norm"
57-
CROSS_ATTENTION = "cross_attention"
58-
CONCATENATE = "concatenate"
59-
SUM = "sum"
60-
SELF_CONDITIONING = "self_conditioning"
61-
CUSTOM = "custom"
62-
63-
6453
class RoPEPositionType(enum.StrEnum):
6554
"""Rotary Position Embedding (RoPE) types."""
6655

@@ -95,6 +84,22 @@ class SkipConnectionMethod(enum.StrEnum):
9584
NORMALIZED_ADD = "normalized_add"
9685

9786

87+
################################################################################
88+
# MARK: Conditioning Mechanism
89+
################################################################################
90+
91+
92+
# Conditioning embeddings correspond to a dictionary with keys corresponding to
93+
# the specification of a conditioning mechanism.
94+
# Example of common conditioning mechanisms:
95+
# - adaptive_norm
96+
# - cross_attention
97+
# - concatenate
98+
# - sum
99+
# - self_conditioning
100+
#
101+
ConditioningEmbeddings = dict[str, Any]
102+
98103
################################################################################
99104
# MARK: Types and protocols
100105
################################################################################
@@ -108,7 +113,7 @@ class ConditionalBackbone(Protocol):
108113
def __call__(
109114
self,
110115
x: DataTree,
111-
conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]],
116+
conditioning_embeddings: ConditioningEmbeddings,
112117
is_training: bool,
113118
) -> DataTree:
114119
...

hackable_diffusion/lib/architecture/conditioning_encoder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
Num = hd_typing.Num
4141

4242

43-
ConditioningMechanism = arch_typing.ConditioningMechanism
43+
4444
EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod
4545

4646
################################################################################
@@ -76,7 +76,7 @@ def __call__(
7676
time: hd_typing.TimeArray,
7777
conditioning: hd_typing.Conditioning | None,
7878
is_training: bool,
79-
) -> dict[ConditioningMechanism, Float['batch ...']]:
79+
) -> arch_typing.ConditioningEmbeddings:
8080
...
8181

8282

@@ -365,9 +365,9 @@ class ConditioningEncoder(nn.Module, BaseConditioningEncoder):
365365
),
366366
embedding_merging_method=EmbeddingMergeMethod.SUM,
367367
conditioning_rules=dict(
368-
label_foo =ConditioningMechanism.ADAPTIVE_NORM,
369-
label_bar=ConditioningMechanism.CROSS_ATTENTION,
370-
time=ConditioningMechanism.ADAPTIVE_NORM,
368+
label_foo='adaptive_norm',
369+
label_bar='cross_attention',
370+
time='adaptive_norm',
371371
),
372372
)
373373
```
@@ -392,7 +392,7 @@ class ConditioningEncoder(nn.Module, BaseConditioningEncoder):
392392
time_embedder: BaseTimeEmbedder
393393
conditioning_embedders: dict[str, BaseEmbedder]
394394
embedding_merging_method: EmbeddingMergeMethod
395-
conditioning_rules: dict[str, ConditioningMechanism]
395+
conditioning_rules: arch_typing.ConditioningEmbeddings
396396
conditioning_dropout_rate: float = 0.0
397397

398398
def setup(self):
@@ -427,7 +427,7 @@ def __call__(
427427
time: hd_typing.TimeTree,
428428
conditioning: hd_typing.Conditioning | None,
429429
is_training: bool,
430-
) -> dict[ConditioningMechanism, Num['batch ...']]:
430+
) -> arch_typing.ConditioningEmbeddings:
431431
"""Encodes and combines time and conditioning signals.
432432
433433
The output is a dictionary where keys are the embedding mechanisms specified

hackable_diffusion/lib/architecture/conditioning_encoder_test.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
################################################################################
2727

2828
EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod
29-
ConditioningMechanism = arch_typing.ConditioningMechanism
29+
3030

3131
################################################################################
3232
# MARK: Tests
@@ -39,13 +39,13 @@ class EncodeConditioningTest(parameterized.TestCase):
3939
(
4040
'test1',
4141
EmbeddingMergeMethod.SUM,
42-
ConditioningMechanism.ADAPTIVE_NORM,
42+
'adaptive_norm',
4343
True,
4444
),
4545
(
4646
'test2',
4747
EmbeddingMergeMethod.CONCAT,
48-
ConditioningMechanism.CROSS_ATTENTION,
48+
'cross_attention',
4949
False,
5050
),
5151
)
@@ -110,13 +110,13 @@ def test_basic(
110110
(
111111
'test1',
112112
EmbeddingMergeMethod.SUM,
113-
ConditioningMechanism.ADAPTIVE_NORM,
113+
'adaptive_norm',
114114
True,
115115
),
116116
(
117117
'test2',
118118
EmbeddingMergeMethod.CONCAT,
119-
ConditioningMechanism.CROSS_ATTENTION,
119+
'cross_attention',
120120
False,
121121
),
122122
)
@@ -182,13 +182,13 @@ def test_mlp_embedder(
182182
(
183183
'test1',
184184
EmbeddingMergeMethod.SUM,
185-
ConditioningMechanism.ADAPTIVE_NORM,
185+
'adaptive_norm',
186186
True,
187187
),
188188
(
189189
'test2',
190190
EmbeddingMergeMethod.CONCAT,
191-
ConditioningMechanism.CROSS_ATTENTION,
191+
'cross_attention',
192192
False,
193193
),
194194
)
@@ -257,13 +257,13 @@ def test_mlp_embedder_process_multiple_keys(
257257
(
258258
'test1',
259259
EmbeddingMergeMethod.SUM,
260-
ConditioningMechanism.ADAPTIVE_NORM,
260+
'adaptive_norm',
261261
True,
262262
),
263263
(
264264
'test2',
265265
EmbeddingMergeMethod.CONCAT,
266-
ConditioningMechanism.CROSS_ATTENTION,
266+
'cross_attention',
267267
False,
268268
),
269269
)
@@ -330,8 +330,8 @@ def test_field_selector_embedder(self):
330330
)
331331
conditioning_encoders = {'image': image_selector}
332332
conditioning_rules = {
333-
'time': ConditioningMechanism.ADAPTIVE_NORM,
334-
'image': ConditioningMechanism.CROSS_ATTENTION,
333+
'time': 'adaptive_norm',
334+
'image': 'cross_attention',
335335
}
336336
embedding_merging_method = EmbeddingMergeMethod.CONCAT
337337

@@ -353,18 +353,18 @@ def test_field_selector_embedder(self):
353353
{'params': params}, t, c, is_training=False, rngs={'dropout': rng}
354354
)
355355

356-
self.assertIn(ConditioningMechanism.CROSS_ATTENTION, output)
356+
self.assertIn('cross_attention', output)
357357
self.assertEqual(
358-
output[ConditioningMechanism.CROSS_ATTENTION].shape,
358+
output['cross_attention'].shape,
359359
(batch_size,) + image_shape,
360360
)
361361
self.assertTrue(
362-
jnp.all(output[ConditioningMechanism.CROSS_ATTENTION] == c['image'])
362+
jnp.all(output['cross_attention'] == c['image'])
363363
)
364364

365-
self.assertIn(ConditioningMechanism.ADAPTIVE_NORM, output)
365+
self.assertIn('adaptive_norm', output)
366366
self.assertEqual(
367-
output[ConditioningMechanism.ADAPTIVE_NORM].shape,
367+
output['adaptive_norm'].shape,
368368
(batch_size, num_features),
369369
)
370370

@@ -384,8 +384,8 @@ def test_field_selector_embedder_fails_on_missing_key(self):
384384
)
385385
conditioning_encoders = {'image': image_selector}
386386
conditioning_rules = {
387-
'time': ConditioningMechanism.ADAPTIVE_NORM,
388-
'image': ConditioningMechanism.CROSS_ATTENTION,
387+
'time': 'adaptive_norm',
388+
'image': 'cross_attention',
389389
}
390390
embedding_merging_method = EmbeddingMergeMethod.CONCAT
391391

@@ -411,7 +411,7 @@ def test_field_selector_embedder_fails_on_missing_key(self):
411411
(
412412
'test1',
413413
EmbeddingMergeMethod.CONCAT,
414-
ConditioningMechanism.CROSS_ATTENTION,
414+
'cross_attention',
415415
8,
416416
16,
417417
False,
@@ -475,7 +475,7 @@ def test_different_num_features(
475475
(
476476
'test1',
477477
EmbeddingMergeMethod.CONCAT,
478-
ConditioningMechanism.CROSS_ATTENTION,
478+
'cross_attention',
479479
8,
480480
9,
481481
10,
@@ -484,7 +484,7 @@ def test_different_num_features(
484484
(
485485
'test2',
486486
EmbeddingMergeMethod.CONCAT,
487-
ConditioningMechanism.CROSS_ATTENTION,
487+
'cross_attention',
488488
8,
489489
9,
490490
10,
@@ -567,7 +567,7 @@ def test_multilabel(
567567
(
568568
'test1',
569569
EmbeddingMergeMethod.CONCAT,
570-
ConditioningMechanism.CROSS_ATTENTION,
570+
'cross_attention',
571571
8,
572572
9,
573573
10,
@@ -576,7 +576,7 @@ def test_multilabel(
576576
(
577577
'test2',
578578
EmbeddingMergeMethod.CONCAT,
579-
ConditioningMechanism.CROSS_ATTENTION,
579+
'cross_attention',
580580
8,
581581
9,
582582
10,
@@ -668,8 +668,8 @@ def test_dropout(self):
668668
conditioning_embedders=conditioning_encoders,
669669
embedding_merging_method=EmbeddingMergeMethod.SUM,
670670
conditioning_rules={
671-
'time': ConditioningMechanism.ADAPTIVE_NORM,
672-
'label': ConditioningMechanism.ADAPTIVE_NORM,
671+
'time': 'adaptive_norm',
672+
'label': 'adaptive_norm',
673673
},
674674
conditioning_dropout_rate=1.0, # Drop all conditioning
675675
)

hackable_diffusion/lib/architecture/discrete.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
Int = hd_typing.Int
3333

3434
ConditionalBackbone = arch_typing.ConditionalBackbone
35-
ConditioningMechanism = arch_typing.ConditioningMechanism
35+
3636

3737
################################################################################
3838
# MARK: Token Embedder
@@ -213,7 +213,7 @@ def __post_init__(self):
213213
def __call__(
214214
self,
215215
x: Int['batch *other 1'],
216-
conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']],
216+
conditioning_embeddings: arch_typing.ConditioningEmbeddings,
217217
is_training: bool,
218218
) -> Float['batch *other V']:
219219

0 commit comments

Comments
 (0)