Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions docs/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ to standardize the architecture's configuration.

### `ConditioningMechanism`

This enum specifies how a conditioning signal is injected into the backbone.
This string specifies how a conditioning signal is injected into the backbone.

* `ADAPTIVE_NORM`: The conditioning embedding is used to modulate the scale
* `adaptive_norm`: The conditioning embedding is used to modulate the scale
and shift in an adaptive normalization layer (e.g., AdaLN).
* `CROSS_ATTENTION`: The conditioning embedding is used as the key and value
* `cross_attention`: The conditioning embedding is used as the key and value
in a cross-attention layer, with the model's intermediate representation as
the query.
* `CONCATENATE`: The conditioning is concatenated to the input of a layer or
* `concatenate`: The conditioning is concatenated to the input of a layer or
module.
* `SUM`: The conditioning is added to the input of a layer or module.
* `sum`: The conditioning is added to the input of a layer or module.

### `EmbeddingMergeMethod`

Expand Down Expand Up @@ -105,8 +105,8 @@ adaptive_norm_emb = jnp.ones((1, 128))
cross_attention_emb = jnp.ones((1, 10, 256)) # 10 tokens, 256 dim

conditioning_embeddings = {
ConditioningMechanism.ADAPTIVE_NORM: adaptive_norm_emb,
ConditioningMechanism.CROSS_ATTENTION: cross_attention_emb,
'adaptive_norm': adaptive_norm_emb,
'cross_attention': cross_attention_emb,
}

unet = Unet(
Expand Down Expand Up @@ -317,9 +317,9 @@ conditioning_encoder = ConditioningEncoder(
},
embedding_merging_method=EmbeddingMergeMethod.SUM,
conditioning_rules={
'time': ConditioningMechanism.ADAPTIVE_NORM,
'label_adanorm': ConditioningMechanism.ADAPTIVE_NORM,
'label_xattn': ConditioningMechanism.CROSS_ATTENTION,
'time': 'adaptive_norm',
'label_adanorm': 'adaptive_norm',
'label_xattn': 'cross_attention',
},
conditioning_dropout_rate=0.1,
)
Expand All @@ -339,8 +339,8 @@ output_embeddings = conditioning_encoder.apply(
)

# 5. Inspect the output
adanorm_emb = output_embeddings[ConditioningMechanism.ADAPTIVE_NORM]
xattn_emb = output_embeddings[ConditioningMechanism.CROSS_ATTENTION]
adanorm_emb = output_embeddings['adaptive_norm']
xattn_emb = output_embeddings['cross_attention']

# The adanorm embedding is the sum of time and label_adanorm embeddings
print(f"Adaptive Norm embedding shape: {adanorm_emb.shape}")
Expand Down
4 changes: 2 additions & 2 deletions hackable_diffusion/kdiff/configs/imagenet64_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def get_config():
},
embedding_merging_method=hd.architecture.EmbeddingMergeMethod.SUM,
conditioning_rules={
"label": hd.architecture.ConditioningMechanism.ADAPTIVE_NORM,
"time": hd.architecture.ConditioningMechanism.ADAPTIVE_NORM,
"label": 'adaptive_norm',
"time": 'adaptive_norm',
},
conditioning_dropout_rate=0.1,
)
Expand Down
1 change: 0 additions & 1 deletion hackable_diffusion/lib/architecture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

# pylint: disable=g-importing-member
from hackable_diffusion.lib.architecture.arch_typing import ConditionalBackbone
from hackable_diffusion.lib.architecture.arch_typing import ConditioningMechanism
from hackable_diffusion.lib.architecture.arch_typing import DownsampleType
from hackable_diffusion.lib.architecture.arch_typing import EmbeddingMergeMethod
from hackable_diffusion.lib.architecture.arch_typing import NormalizationType
Expand Down
31 changes: 18 additions & 13 deletions hackable_diffusion/lib/architecture/arch_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

import enum
from typing import Callable, Protocol
from typing import Any, Callable, Protocol
from hackable_diffusion.lib import hd_typing
import jax

Expand Down Expand Up @@ -50,17 +50,6 @@ class EmbeddingMergeMethod(enum.StrEnum):
CONCAT = "concat"


class ConditioningMechanism(enum.StrEnum):
"""Types of conditioning mechanisms."""

ADAPTIVE_NORM = "adaptive_norm"
CROSS_ATTENTION = "cross_attention"
CONCATENATE = "concatenate"
SUM = "sum"
SELF_CONDITIONING = "self_conditioning"
CUSTOM = "custom"


class RoPEPositionType(enum.StrEnum):
"""Rotary Position Embedding (RoPE) types."""

Expand Down Expand Up @@ -95,6 +84,22 @@ class SkipConnectionMethod(enum.StrEnum):
NORMALIZED_ADD = "normalized_add"


################################################################################
# MARK: Conditioning Mechanism
################################################################################


# Conditioning embeddings correspond to a dictionary whose keys specify the
# conditioning mechanism to use.
# Example of common conditioning mechanisms:
# - adaptive_norm
# - cross_attention
# - concatenate
# - sum
# - self_conditioning
#
ConditioningEmbeddings = dict[str, Any]

################################################################################
# MARK: Types and protocols
################################################################################
Expand All @@ -108,7 +113,7 @@ class ConditionalBackbone(Protocol):
def __call__(
self,
x: DataTree,
conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]],
conditioning_embeddings: ConditioningEmbeddings,
is_training: bool,
) -> DataTree:
...
Expand Down
14 changes: 7 additions & 7 deletions hackable_diffusion/lib/architecture/conditioning_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
Num = hd_typing.Num


ConditioningMechanism = arch_typing.ConditioningMechanism

EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod

################################################################################
Expand Down Expand Up @@ -76,7 +76,7 @@ def __call__(
time: hd_typing.TimeArray,
conditioning: hd_typing.Conditioning | None,
is_training: bool,
) -> dict[ConditioningMechanism, Float['batch ...']]:
) -> arch_typing.ConditioningEmbeddings:
...


Expand Down Expand Up @@ -365,9 +365,9 @@ class ConditioningEncoder(nn.Module, BaseConditioningEncoder):
),
embedding_merging_method=EmbeddingMergeMethod.SUM,
conditioning_rules=dict(
label_foo =ConditioningMechanism.ADAPTIVE_NORM,
label_bar=ConditioningMechanism.CROSS_ATTENTION,
time=ConditioningMechanism.ADAPTIVE_NORM,
label_foo ='adaptive_norm',
label_bar='cross_attention',
time='adaptive_norm',
),
)
```
Expand All @@ -392,7 +392,7 @@ class ConditioningEncoder(nn.Module, BaseConditioningEncoder):
time_embedder: BaseTimeEmbedder
conditioning_embedders: dict[str, BaseEmbedder]
embedding_merging_method: EmbeddingMergeMethod
conditioning_rules: dict[str, ConditioningMechanism]
conditioning_rules: arch_typing.ConditioningEmbeddings
conditioning_dropout_rate: float = 0.0

def setup(self):
Expand Down Expand Up @@ -427,7 +427,7 @@ def __call__(
time: hd_typing.TimeTree,
conditioning: hd_typing.Conditioning | None,
is_training: bool,
) -> dict[ConditioningMechanism, Num['batch ...']]:
) -> arch_typing.ConditioningEmbeddings:
"""Encodes and combines time and conditioning signals.

The output is a dictionary where keys are the embedding mechanisms specified
Expand Down
50 changes: 25 additions & 25 deletions hackable_diffusion/lib/architecture/conditioning_encoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
################################################################################

EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod
ConditioningMechanism = arch_typing.ConditioningMechanism


################################################################################
# MARK: Tests
Expand All @@ -39,13 +39,13 @@ class EncodeConditioningTest(parameterized.TestCase):
(
'test1',
EmbeddingMergeMethod.SUM,
ConditioningMechanism.ADAPTIVE_NORM,
'adaptive_norm',
True,
),
(
'test2',
EmbeddingMergeMethod.CONCAT,
ConditioningMechanism.CROSS_ATTENTION,
'cross_attention',
False,
),
)
Expand Down Expand Up @@ -110,13 +110,13 @@ def test_basic(
(
'test1',
EmbeddingMergeMethod.SUM,
ConditioningMechanism.ADAPTIVE_NORM,
'adaptive_norm',
True,
),
(
'test2',
EmbeddingMergeMethod.CONCAT,
ConditioningMechanism.CROSS_ATTENTION,
'cross_attention',
False,
),
)
Expand Down Expand Up @@ -182,13 +182,13 @@ def test_mlp_embedder(
(
'test1',
EmbeddingMergeMethod.SUM,
ConditioningMechanism.ADAPTIVE_NORM,
'adaptive_norm',
True,
),
(
'test2',
EmbeddingMergeMethod.CONCAT,
ConditioningMechanism.CROSS_ATTENTION,
'cross_attention',
False,
),
)
Expand Down Expand Up @@ -257,13 +257,13 @@ def test_mlp_embedder_process_multiple_keys(
(
'test1',
EmbeddingMergeMethod.SUM,
ConditioningMechanism.ADAPTIVE_NORM,
'adaptive_norm',
True,
),
(
'test2',
EmbeddingMergeMethod.CONCAT,
ConditioningMechanism.CROSS_ATTENTION,
'cross_attention',
False,
),
)
Expand Down Expand Up @@ -330,8 +330,8 @@ def test_field_selector_embedder(self):
)
conditioning_encoders = {'image': image_selector}
conditioning_rules = {
'time': ConditioningMechanism.ADAPTIVE_NORM,
'image': ConditioningMechanism.CROSS_ATTENTION,
'time': 'adaptive_norm',
'image': 'cross_attention',
}
embedding_merging_method = EmbeddingMergeMethod.CONCAT

Expand All @@ -353,18 +353,18 @@ def test_field_selector_embedder(self):
{'params': params}, t, c, is_training=False, rngs={'dropout': rng}
)

self.assertIn(ConditioningMechanism.CROSS_ATTENTION, output)
self.assertIn('cross_attention', output)
self.assertEqual(
output[ConditioningMechanism.CROSS_ATTENTION].shape,
output['cross_attention'].shape,
(batch_size,) + image_shape,
)
self.assertTrue(
jnp.all(output[ConditioningMechanism.CROSS_ATTENTION] == c['image'])
jnp.all(output['cross_attention'] == c['image'])
)

self.assertIn(ConditioningMechanism.ADAPTIVE_NORM, output)
self.assertIn('adaptive_norm', output)
self.assertEqual(
output[ConditioningMechanism.ADAPTIVE_NORM].shape,
output['adaptive_norm'].shape,
(batch_size, num_features),
)

Expand All @@ -384,8 +384,8 @@ def test_field_selector_embedder_fails_on_missing_key(self):
)
conditioning_encoders = {'image': image_selector}
conditioning_rules = {
'time': ConditioningMechanism.ADAPTIVE_NORM,
'image': ConditioningMechanism.CROSS_ATTENTION,
'time': 'adaptive_norm',
'image': 'cross_attention',
}
embedding_merging_method = EmbeddingMergeMethod.CONCAT

Expand All @@ -411,7 +411,7 @@ def test_field_selector_embedder_fails_on_missing_key(self):
(
'test1',
EmbeddingMergeMethod.CONCAT,
ConditioningMechanism.CROSS_ATTENTION,
'cross_attention',
8,
16,
False,
Expand Down Expand Up @@ -475,7 +475,7 @@ def test_different_num_features(
(
'test1',
EmbeddingMergeMethod.CONCAT,
ConditioningMechanism.CROSS_ATTENTION,
'cross_attention',
8,
9,
10,
Expand All @@ -484,7 +484,7 @@ def test_different_num_features(
(
'test2',
EmbeddingMergeMethod.CONCAT,
ConditioningMechanism.CROSS_ATTENTION,
'cross_attention',
8,
9,
10,
Expand Down Expand Up @@ -567,7 +567,7 @@ def test_multilabel(
(
'test1',
EmbeddingMergeMethod.CONCAT,
ConditioningMechanism.CROSS_ATTENTION,
'cross_attention',
8,
9,
10,
Expand All @@ -576,7 +576,7 @@ def test_multilabel(
(
'test2',
EmbeddingMergeMethod.CONCAT,
ConditioningMechanism.CROSS_ATTENTION,
'cross_attention',
8,
9,
10,
Expand Down Expand Up @@ -668,8 +668,8 @@ def test_dropout(self):
conditioning_embedders=conditioning_encoders,
embedding_merging_method=EmbeddingMergeMethod.SUM,
conditioning_rules={
'time': ConditioningMechanism.ADAPTIVE_NORM,
'label': ConditioningMechanism.ADAPTIVE_NORM,
'time': 'adaptive_norm',
'label': 'adaptive_norm',
},
conditioning_dropout_rate=1.0, # Drop all conditioning
)
Expand Down
4 changes: 2 additions & 2 deletions hackable_diffusion/lib/architecture/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
Int = hd_typing.Int

ConditionalBackbone = arch_typing.ConditionalBackbone
ConditioningMechanism = arch_typing.ConditioningMechanism


################################################################################
# MARK: Token Embedder
Expand Down Expand Up @@ -213,7 +213,7 @@ def __post_init__(self):
def __call__(
self,
x: Int['batch *other 1'],
conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']],
conditioning_embeddings: arch_typing.ConditioningEmbeddings,
is_training: bool,
) -> Float['batch *other V']:

Expand Down
Loading
Loading