diff --git a/docs/architecture.md b/docs/architecture.md
index 627d19b..100f41c 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -32,16 +32,16 @@ to standardize the architecture's configuration.
 
 ### `ConditioningMechanism`
 
-This enum specifies how a conditioning signal is injected into the backbone.
+This string specifies how a conditioning signal is injected into the backbone.
 
- * `ADAPTIVE_NORM`: The conditioning embedding is used to modulate the scale
+ * `adaptive_norm`: The conditioning embedding is used to modulate the scale
    and shift in an adaptive normalization layer (e.g., AdaLN).
- * `CROSS_ATTENTION`: The conditioning embedding is used as the key and value
+ * `cross_attention`: The conditioning embedding is used as the key and value
    in a cross-attention layer, with the model's intermediate representation as
    the query.
- * `CONCATENATE`: The conditioning is concatenated to the input of a layer or
+ * `concatenate`: The conditioning is concatenated to the input of a layer or
    module.
- * `SUM`: The conditioning is added to the input of a layer or module.
+ * `sum`: The conditioning is added to the input of a layer or module.
 
 ### `EmbeddingMergeMethod`
 
@@ -105,8 +105,8 @@ adaptive_norm_emb = jnp.ones((1, 128))
 cross_attention_emb = jnp.ones((1, 10, 256)) # 10 tokens, 256 dim
 
 conditioning_embeddings = {
-    ConditioningMechanism.ADAPTIVE_NORM: adaptive_norm_emb,
-    ConditioningMechanism.CROSS_ATTENTION: cross_attention_emb,
+    'adaptive_norm': adaptive_norm_emb,
+    'cross_attention': cross_attention_emb,
 }
 
 unet = Unet(
@@ -317,9 +317,9 @@ conditioning_encoder = ConditioningEncoder(
     },
     embedding_merging_method=EmbeddingMergeMethod.SUM,
     conditioning_rules={
-        'time': ConditioningMechanism.ADAPTIVE_NORM,
-        'label_adanorm': ConditioningMechanism.ADAPTIVE_NORM,
-        'label_xattn': ConditioningMechanism.CROSS_ATTENTION,
+        'time': 'adaptive_norm',
+        'label_adanorm': 'adaptive_norm',
+        'label_xattn': 'cross_attention',
     },
     conditioning_dropout_rate=0.1,
 )
@@ -339,8 +339,8 @@ output_embeddings = conditioning_encoder.apply(
 )
 
 # 5. Inspect the output
-adanorm_emb = output_embeddings[ConditioningMechanism.ADAPTIVE_NORM]
-xattn_emb = output_embeddings[ConditioningMechanism.CROSS_ATTENTION]
+adanorm_emb = output_embeddings['adaptive_norm']
+xattn_emb = output_embeddings['cross_attention']
 
 # The adanorm embedding is the sum of time and label_adanorm embeddings
 print(f"Adaptive Norm embedding shape: {adanorm_emb.shape}")
diff --git a/hackable_diffusion/kdiff/configs/imagenet64_unet.py b/hackable_diffusion/kdiff/configs/imagenet64_unet.py
index e4a73cf..7b31134 100644
--- a/hackable_diffusion/kdiff/configs/imagenet64_unet.py
+++ b/hackable_diffusion/kdiff/configs/imagenet64_unet.py
@@ -65,8 +65,8 @@ def get_config():
           },
           embedding_merging_method=hd.architecture.EmbeddingMergeMethod.SUM,
           conditioning_rules={
-              "label": hd.architecture.ConditioningMechanism.ADAPTIVE_NORM,
-              "time": hd.architecture.ConditioningMechanism.ADAPTIVE_NORM,
+              "label": "adaptive_norm",
+              "time": "adaptive_norm",
           },
           conditioning_dropout_rate=0.1,
       )
diff --git a/hackable_diffusion/lib/architecture/__init__.py b/hackable_diffusion/lib/architecture/__init__.py
index 52b5d1e..ab03dc2 100644
--- a/hackable_diffusion/lib/architecture/__init__.py
+++ b/hackable_diffusion/lib/architecture/__init__.py
@@ -16,7 +16,6 @@
 
 # pylint: disable=g-importing-member
 from hackable_diffusion.lib.architecture.arch_typing import ConditionalBackbone
-from hackable_diffusion.lib.architecture.arch_typing import ConditioningMechanism
 from hackable_diffusion.lib.architecture.arch_typing import DownsampleType
 from hackable_diffusion.lib.architecture.arch_typing import EmbeddingMergeMethod
 from hackable_diffusion.lib.architecture.arch_typing import NormalizationType
diff --git a/hackable_diffusion/lib/architecture/arch_typing.py b/hackable_diffusion/lib/architecture/arch_typing.py
index 0a8cd8d..09fd006 100644
--- a/hackable_diffusion/lib/architecture/arch_typing.py
+++ b/hackable_diffusion/lib/architecture/arch_typing.py
@@ -19,7 +19,7 @@
 """
 
 import enum
-from typing import Callable, Protocol
+from typing import Any, Callable, Protocol
 
 from hackable_diffusion.lib import hd_typing
 import jax
@@ -50,17 +50,6 @@ class EmbeddingMergeMethod(enum.StrEnum):
   CONCAT = "concat"
 
 
-class ConditioningMechanism(enum.StrEnum):
-  """Types of conditioning mechanisms."""
-
-  ADAPTIVE_NORM = "adaptive_norm"
-  CROSS_ATTENTION = "cross_attention"
-  CONCATENATE = "concatenate"
-  SUM = "sum"
-  SELF_CONDITIONING = "self_conditioning"
-  CUSTOM = "custom"
-
-
 class RoPEPositionType(enum.StrEnum):
   """Rotary Position Embedding (RoPE) types."""
 
@@ -95,6 +84,22 @@ class SkipConnectionMethod(enum.StrEnum):
   NORMALIZED_ADD = "normalized_add"
 
 
+################################################################################
+# MARK: Conditioning Mechanism
+################################################################################
+
+
+# Conditioning embeddings are passed as a dictionary whose keys name the
+# conditioning mechanism through which each embedding is injected.
+# Examples of common conditioning mechanisms:
+# - adaptive_norm
+# - cross_attention
+# - concatenate
+# - sum
+# - self_conditioning
+#
+ConditioningEmbeddings = dict[str, Any]
+
 ################################################################################
 # MARK: Types and protocols
 ################################################################################
@@ -108,7 +113,7 @@ class ConditionalBackbone(Protocol):
   def __call__(
       self,
       x: DataTree,
-      conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]],
+      conditioning_embeddings: ConditioningEmbeddings,
       is_training: bool,
   ) -> DataTree:
     ...
diff --git a/hackable_diffusion/lib/architecture/conditioning_encoder.py b/hackable_diffusion/lib/architecture/conditioning_encoder.py
index d368743..0b90cc2 100644
--- a/hackable_diffusion/lib/architecture/conditioning_encoder.py
+++ b/hackable_diffusion/lib/architecture/conditioning_encoder.py
@@ -40,7 +40,7 @@
 
 Num = hd_typing.Num
 
-ConditioningMechanism = arch_typing.ConditioningMechanism
+
 EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod
 
 ################################################################################
@@ -76,7 +76,7 @@ def __call__(
       time: hd_typing.TimeArray,
       conditioning: hd_typing.Conditioning | None,
       is_training: bool,
-  ) -> dict[ConditioningMechanism, Float['batch ...']]:
+  ) -> arch_typing.ConditioningEmbeddings:
     ...
 
 
@@ -365,9 +365,9 @@ class ConditioningEncoder(nn.Module, BaseConditioningEncoder):
         ),
         embedding_merging_method=EmbeddingMergeMethod.SUM,
         conditioning_rules=dict(
-            label_foo =ConditioningMechanism.ADAPTIVE_NORM,
-            label_bar=ConditioningMechanism.CROSS_ATTENTION,
-            time=ConditioningMechanism.ADAPTIVE_NORM,
+            label_foo='adaptive_norm',
+            label_bar='cross_attention',
+            time='adaptive_norm',
         ),
     )
     ```
@@ -392,7 +392,7 @@ class ConditioningEncoder(nn.Module, BaseConditioningEncoder):
   time_embedder: BaseTimeEmbedder
   conditioning_embedders: dict[str, BaseEmbedder]
   embedding_merging_method: EmbeddingMergeMethod
-  conditioning_rules: dict[str, ConditioningMechanism]
+  conditioning_rules: dict[str, str]
   conditioning_dropout_rate: float = 0.0
 
   def setup(self):
@@ -427,7 +427,7 @@ def __call__(
       time: hd_typing.TimeTree,
       conditioning: hd_typing.Conditioning | None,
       is_training: bool,
-  ) -> dict[ConditioningMechanism, Num['batch ...']]:
+  ) -> arch_typing.ConditioningEmbeddings:
     """Encodes and combines time and conditioning signals.
The output is a dictionary where keys are the embedding mechanisms specified diff --git a/hackable_diffusion/lib/architecture/conditioning_encoder_test.py b/hackable_diffusion/lib/architecture/conditioning_encoder_test.py index 4b6849b..e42eaa9 100644 --- a/hackable_diffusion/lib/architecture/conditioning_encoder_test.py +++ b/hackable_diffusion/lib/architecture/conditioning_encoder_test.py @@ -26,7 +26,7 @@ ################################################################################ EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod -ConditioningMechanism = arch_typing.ConditioningMechanism + ################################################################################ # MARK: Tests @@ -39,13 +39,13 @@ class EncodeConditioningTest(parameterized.TestCase): ( 'test1', EmbeddingMergeMethod.SUM, - ConditioningMechanism.ADAPTIVE_NORM, + 'adaptive_norm', True, ), ( 'test2', EmbeddingMergeMethod.CONCAT, - ConditioningMechanism.CROSS_ATTENTION, + 'cross_attention', False, ), ) @@ -110,13 +110,13 @@ def test_basic( ( 'test1', EmbeddingMergeMethod.SUM, - ConditioningMechanism.ADAPTIVE_NORM, + 'adaptive_norm', True, ), ( 'test2', EmbeddingMergeMethod.CONCAT, - ConditioningMechanism.CROSS_ATTENTION, + 'cross_attention', False, ), ) @@ -182,13 +182,13 @@ def test_mlp_embedder( ( 'test1', EmbeddingMergeMethod.SUM, - ConditioningMechanism.ADAPTIVE_NORM, + 'adaptive_norm', True, ), ( 'test2', EmbeddingMergeMethod.CONCAT, - ConditioningMechanism.CROSS_ATTENTION, + 'cross_attention', False, ), ) @@ -257,13 +257,13 @@ def test_mlp_embedder_process_multiple_keys( ( 'test1', EmbeddingMergeMethod.SUM, - ConditioningMechanism.ADAPTIVE_NORM, + 'adaptive_norm', True, ), ( 'test2', EmbeddingMergeMethod.CONCAT, - ConditioningMechanism.CROSS_ATTENTION, + 'cross_attention', False, ), ) @@ -330,8 +330,8 @@ def test_field_selector_embedder(self): ) conditioning_encoders = {'image': image_selector} conditioning_rules = { - 'time': ConditioningMechanism.ADAPTIVE_NORM, - 'image': ConditioningMechanism.CROSS_ATTENTION, + 'time': 'adaptive_norm', + 'image': 'cross_attention', } embedding_merging_method = EmbeddingMergeMethod.CONCAT @@ -353,18 +353,18 @@ def test_field_selector_embedder(self): {'params': params}, t, c, is_training=False, rngs={'dropout': rng} ) - self.assertIn(ConditioningMechanism.CROSS_ATTENTION, output) + self.assertIn('cross_attention', output) self.assertEqual( - output[ConditioningMechanism.CROSS_ATTENTION].shape, + output['cross_attention'].shape, (batch_size,) + image_shape, ) self.assertTrue( - jnp.all(output[ConditioningMechanism.CROSS_ATTENTION] == c['image']) + jnp.all(output['cross_attention'] == c['image']) ) - self.assertIn(ConditioningMechanism.ADAPTIVE_NORM, output) + self.assertIn('adaptive_norm', output) self.assertEqual( - output[ConditioningMechanism.ADAPTIVE_NORM].shape, + output['adaptive_norm'].shape, (batch_size, num_features), ) @@ -384,8 +384,8 @@ def test_field_selector_embedder_fails_on_missing_key(self): ) conditioning_encoders = {'image': image_selector} conditioning_rules = { - 'time': ConditioningMechanism.ADAPTIVE_NORM, - 'image': ConditioningMechanism.CROSS_ATTENTION, + 'time': 'adaptive_norm', + 'image': 'cross_attention', } embedding_merging_method = EmbeddingMergeMethod.CONCAT @@ -411,7 +411,7 @@ def test_field_selector_embedder_fails_on_missing_key(self): ( 'test1', EmbeddingMergeMethod.CONCAT, - ConditioningMechanism.CROSS_ATTENTION, + 'cross_attention', 8, 16, False, @@ -475,7 +475,7 @@ def test_different_num_features( ( 'test1', 
EmbeddingMergeMethod.CONCAT, - ConditioningMechanism.CROSS_ATTENTION, + 'cross_attention', 8, 9, 10, @@ -484,7 +484,7 @@ def test_different_num_features( ( 'test2', EmbeddingMergeMethod.CONCAT, - ConditioningMechanism.CROSS_ATTENTION, + 'cross_attention', 8, 9, 10, @@ -567,7 +567,7 @@ def test_multilabel( ( 'test1', EmbeddingMergeMethod.CONCAT, - ConditioningMechanism.CROSS_ATTENTION, + 'cross_attention', 8, 9, 10, @@ -576,7 +576,7 @@ def test_multilabel( ( 'test2', EmbeddingMergeMethod.CONCAT, - ConditioningMechanism.CROSS_ATTENTION, + 'cross_attention', 8, 9, 10, @@ -668,8 +668,8 @@ def test_dropout(self): conditioning_embedders=conditioning_encoders, embedding_merging_method=EmbeddingMergeMethod.SUM, conditioning_rules={ - 'time': ConditioningMechanism.ADAPTIVE_NORM, - 'label': ConditioningMechanism.ADAPTIVE_NORM, + 'time': 'adaptive_norm', + 'label': 'adaptive_norm', }, conditioning_dropout_rate=1.0, # Drop all conditioning ) diff --git a/hackable_diffusion/lib/architecture/discrete.py b/hackable_diffusion/lib/architecture/discrete.py index 77d93ac..afda938 100644 --- a/hackable_diffusion/lib/architecture/discrete.py +++ b/hackable_diffusion/lib/architecture/discrete.py @@ -32,7 +32,7 @@ Int = hd_typing.Int ConditionalBackbone = arch_typing.ConditionalBackbone -ConditioningMechanism = arch_typing.ConditioningMechanism + ################################################################################ # MARK: Token Embedder @@ -213,7 +213,7 @@ def __post_init__(self): def __call__( self, x: Int['batch *other 1'], - conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']], + conditioning_embeddings: arch_typing.ConditioningEmbeddings, is_training: bool, ) -> Float['batch *other V']: diff --git a/hackable_diffusion/lib/architecture/discrete_test.py b/hackable_diffusion/lib/architecture/discrete_test.py index 75d42ee..6cc0d04 100644 --- a/hackable_diffusion/lib/architecture/discrete_test.py +++ b/hackable_diffusion/lib/architecture/discrete_test.py @@ -31,7 +31,7 @@ # MARK: Type Aliases ################################################################################ -ConditioningMechanism = arch_typing.ConditioningMechanism + ################################################################################ # MARK: Tests @@ -50,15 +50,15 @@ def setUp(self): self.cond_dim = 16 self.discrete_x = jnp.ones((self.batch_size, *self.shape), dtype=jnp.int32) self.concatenate_emb = { - ConditioningMechanism.CONCATENATE: jnp.ones( + 'concatenate': jnp.ones( (self.batch_size, self.cond_dim) ), } self.sum_emb = { - ConditioningMechanism.SUM: jnp.ones((self.batch_size, self.cond_dim)), + 'sum': jnp.ones((self.batch_size, self.cond_dim)), } self.adaptive_norm_emb = { - ConditioningMechanism.ADAPTIVE_NORM: jnp.ones( + 'adaptive_norm': jnp.ones( (self.batch_size, self.cond_dim) ), } @@ -68,7 +68,7 @@ def setUp(self): activation='relu', dropout_rate=0.0, zero_init_output=False, - conditioning_mechanism=ConditioningMechanism.CONCATENATE, + conditioning_mechanism='concatenate', ) self.unet_module = unet.Unet( base_channels=8, diff --git a/hackable_diffusion/lib/architecture/dit.py b/hackable_diffusion/lib/architecture/dit.py index fec0891..d3b4d70 100644 --- a/hackable_diffusion/lib/architecture/dit.py +++ b/hackable_diffusion/lib/architecture/dit.py @@ -39,7 +39,7 @@ DataArray = hd_typing.DataArray ConditionalBackbone = arch_typing.ConditionalBackbone -ConditioningMechanism = arch_typing.ConditioningMechanism + NormalizationType = arch_typing.NormalizationType 
################################################################################ @@ -100,11 +100,11 @@ def setup(self): def __call__( self, x: DataArray, - conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]], + conditioning_embeddings: arch_typing.ConditioningEmbeddings, is_training: bool, ) -> DataArray: adaptive_norm_emb = conditioning_embeddings.get( - ConditioningMechanism.ADAPTIVE_NORM + 'adaptive_norm' ) if adaptive_norm_emb is None: raise ValueError("adaptive_norm_emb must be provided.") diff --git a/hackable_diffusion/lib/architecture/dit_test.py b/hackable_diffusion/lib/architecture/dit_test.py index 4423d63..4255f14 100644 --- a/hackable_diffusion/lib/architecture/dit_test.py +++ b/hackable_diffusion/lib/architecture/dit_test.py @@ -28,7 +28,7 @@ # MARK: Type Aliases ################################################################################ -ConditioningMechanism = arch_typing.ConditioningMechanism + ################################################################################ # MARK: Tests @@ -67,7 +67,7 @@ def test_output_shape_with_patchify(self): ), ) conditioning_embeddings = { - ConditioningMechanism.ADAPTIVE_NORM: jnp.ones( + 'adaptive_norm': jnp.ones( (self.batch_size, self.cond_dim) ), } @@ -103,7 +103,7 @@ def test_variable_shapes_with_patchify(self): absolute_posenc=dit_blocks.PositionalEmbedding(), ) conditioning_embeddings = { - ConditioningMechanism.ADAPTIVE_NORM: jnp.ones( + 'adaptive_norm': jnp.ones( (self.batch_size, self.cond_dim) ), } @@ -212,7 +212,7 @@ def test_output_shape_tokens(self): input_shape = (self.batch_size, self.sequence_length, self.embedding_dim) x = jnp.ones(input_shape) conditioning_embeddings = { - ConditioningMechanism.ADAPTIVE_NORM: jnp.ones( + 'adaptive_norm': jnp.ones( (self.batch_size, self.cond_dim) ), } diff --git a/hackable_diffusion/lib/architecture/mlp.py b/hackable_diffusion/lib/architecture/mlp.py index f2c2ab0..d7bfbf8 100644 --- a/hackable_diffusion/lib/architecture/mlp.py +++ b/hackable_diffusion/lib/architecture/mlp.py @@ -37,7 +37,7 @@ DataArray = hd_typing.DataArray ConditionalBackbone = arch_typing.ConditionalBackbone -ConditioningMechanism = arch_typing.ConditioningMechanism + ################################################################################ # MARK: ConditionalMLP @@ -68,7 +68,7 @@ class ConditionalMLP(nn.Module, ConditionalBackbone): zero_init_output: bool dropout_rate: float conditioning_mechanism: Literal[ - ConditioningMechanism.SUM, ConditioningMechanism.CONCATENATE + 'sum', 'concatenate' ] dtype: DType = jnp.float32 @@ -77,7 +77,7 @@ class ConditionalMLP(nn.Module, ConditionalBackbone): def __call__( self, x: DataArray, - conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']], + conditioning_embeddings: arch_typing.ConditioningEmbeddings, *, is_training: bool, ) -> DataArray: @@ -100,7 +100,7 @@ def __call__( c_emb = conditioning_embeddings.get(self.conditioning_mechanism) if c_emb is None: raise ValueError('Conditioning embeddings are not provided.') - if self.conditioning_mechanism == ConditioningMechanism.SUM: + if self.conditioning_mechanism == 'sum': # Since the conditioning embedding may not have the same dimension as # `x_emb`, we project it to the same size as `x_emb`. 
c_emb = nn.Dense( @@ -109,7 +109,7 @@ def __call__( name='Dense_Projection_Conditioning', )(c_emb) emb = c_emb + x_emb - elif self.conditioning_mechanism == ConditioningMechanism.CONCATENATE: + elif self.conditioning_mechanism == 'concatenate': emb = jnp.concatenate((c_emb, x_emb), axis=-1) else: raise ValueError( diff --git a/hackable_diffusion/lib/architecture/mlp_test.py b/hackable_diffusion/lib/architecture/mlp_test.py index 284fc7a..115a768 100644 --- a/hackable_diffusion/lib/architecture/mlp_test.py +++ b/hackable_diffusion/lib/architecture/mlp_test.py @@ -29,7 +29,7 @@ # MARK: Type Aliases ################################################################################ -ConditioningMechanism = arch_typing.ConditioningMechanism + ################################################################################ # MARK: Tests @@ -49,12 +49,12 @@ def setUp(self): self.x = jnp.ones((self.batch_size, *self.shape)) self.flatten_x = jnp.reshape(self.x, (self.batch_size, -1)) self.concatenate_emb = { - ConditioningMechanism.CONCATENATE: jnp.ones( + 'concatenate': jnp.ones( (self.batch_size, self.cond_dim) ), } self.sum_emb = { - ConditioningMechanism.SUM: jnp.ones((self.batch_size, self.cond_dim)), + 'sum': jnp.ones((self.batch_size, self.cond_dim)), } # ConditionalMLP tests @@ -67,7 +67,7 @@ def test_conditional_mlp_output_shape(self): activation='relu', dropout_rate=0.0, zero_init_output=False, - conditioning_mechanism=ConditioningMechanism.CONCATENATE, + conditioning_mechanism='concatenate', ) variables = mlp_module.init( self.key, @@ -91,7 +91,7 @@ def test_conditional_mlp_zero_init_output(self): activation='relu', dropout_rate=0.0, zero_init_output=True, - conditioning_mechanism=ConditioningMechanism.CONCATENATE, + conditioning_mechanism='concatenate', ) variables = mlp_module.init( self.key, @@ -125,7 +125,7 @@ def test_conditional_mlp_concatenate_variables_shape(self): activation='relu', dropout_rate=0.0, zero_init_output=False, - conditioning_mechanism=ConditioningMechanism.CONCATENATE, + conditioning_mechanism='concatenate', ) variables = mlp_module.init( self.key, @@ -191,7 +191,7 @@ def test_conditional_mlp_sum_variables_shape(self): activation='relu', dropout_rate=0.0, zero_init_output=False, - conditioning_mechanism=ConditioningMechanism.SUM, + conditioning_mechanism='sum', ) variables = mlp_module.init( self.key, diff --git a/hackable_diffusion/lib/architecture/riemannian_test.py b/hackable_diffusion/lib/architecture/riemannian_test.py index 3e010d0..c4c1003 100644 --- a/hackable_diffusion/lib/architecture/riemannian_test.py +++ b/hackable_diffusion/lib/architecture/riemannian_test.py @@ -34,7 +34,7 @@ def test_riemannian_backbone_projection(self): activation='relu', zero_init_output=True, dropout_rate=0.0, - conditioning_mechanism=mlp.ConditioningMechanism.CONCATENATE, + conditioning_mechanism='concatenate', ) model = riemannian.RiemannianConditionalBackbone( backbone=backbone, @@ -45,9 +45,8 @@ def test_riemannian_backbone_projection(self): xt = manifold.random_uniform(key, (4, 3)) time_emb = jnp.array([[0.5], [0.5], [0.5], [0.5]]) - # conditioning_embeddings must be a dict keyed by ConditioningMechanism. 
conditioning_embeddings = { - arch_typing.ConditioningMechanism.CONCATENATE: time_emb, + 'concatenate': time_emb, } variables = model.init(key, xt, conditioning_embeddings, is_training=False) @@ -69,7 +68,7 @@ def test_variable_names_and_shapes(self): activation='relu', zero_init_output=True, dropout_rate=0.0, - conditioning_mechanism=mlp.ConditioningMechanism.CONCATENATE, + conditioning_mechanism='concatenate', ) model = riemannian.RiemannianConditionalBackbone( backbone=backbone, @@ -80,7 +79,7 @@ def test_variable_names_and_shapes(self): xt = manifold.random_uniform(key, (4, 3)) time_emb = jnp.array([[0.5], [0.5], [0.5], [0.5]]) conditioning_embeddings = { - arch_typing.ConditioningMechanism.CONCATENATE: time_emb, + 'concatenate': time_emb, } variables = model.init(key, xt, conditioning_embeddings, is_training=False) diff --git a/hackable_diffusion/lib/architecture/simplicial.py b/hackable_diffusion/lib/architecture/simplicial.py index f5f8eb3..3f90264 100644 --- a/hackable_diffusion/lib/architecture/simplicial.py +++ b/hackable_diffusion/lib/architecture/simplicial.py @@ -56,7 +56,7 @@ Int = hd_typing.Int ConditionalBackbone = arch_typing.ConditionalBackbone -ConditioningMechanism = arch_typing.ConditioningMechanism + BaseProjector = discrete.BaseProjector @@ -162,7 +162,7 @@ def __post_init__(self): def __call__( self, x: Float['batch *other V'], - conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']], + conditioning_embeddings: arch_typing.ConditioningEmbeddings, is_training: bool, ) -> Float['batch *other V']: diff --git a/hackable_diffusion/lib/architecture/unet.py b/hackable_diffusion/lib/architecture/unet.py index 60b4aa0..e593e82 100644 --- a/hackable_diffusion/lib/architecture/unet.py +++ b/hackable_diffusion/lib/architecture/unet.py @@ -38,7 +38,7 @@ RoPEPositionType = arch_typing.RoPEPositionType SkipConnectionMethod = arch_typing.SkipConnectionMethod ConditionalBackbone = arch_typing.ConditionalBackbone -ConditioningMechanism = arch_typing.ConditioningMechanism + ################################################################################ # MARK: Unet @@ -163,21 +163,21 @@ def setup(self): def __call__( self, x: Float["batch height width channels"], - conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]], + conditioning_embeddings: arch_typing.ConditioningEmbeddings, *, is_training: bool, ) -> Float["batch height width output_channels"]: # Extract conditioning embeddings to use with adaptive normalization. adaptive_norm_emb = conditioning_embeddings.get( - ConditioningMechanism.ADAPTIVE_NORM + 'adaptive_norm' ) if adaptive_norm_emb is None: raise ValueError("adaptive_norm_emb must be provided.") # Extract conditioning embeddings to use with cross attention. 
cross_attention_emb = conditioning_embeddings.get( - ConditioningMechanism.CROSS_ATTENTION + 'cross_attention' ) if any(self.cross_attention_bool) and cross_attention_emb is None: raise ValueError( diff --git a/hackable_diffusion/lib/architecture/unet_test.py b/hackable_diffusion/lib/architecture/unet_test.py index fbbff4f..cd0b54c 100644 --- a/hackable_diffusion/lib/architecture/unet_test.py +++ b/hackable_diffusion/lib/architecture/unet_test.py @@ -34,7 +34,7 @@ UpsampleType = arch_typing.UpsampleType SkipConnectionMethod = arch_typing.SkipConnectionMethod INVALID_INT = arch_typing.INVALID_INT -ConditioningMechanism = arch_typing.ConditioningMechanism + ################################################################################ # MARK: Tests @@ -102,8 +102,8 @@ def test_output_shape(self, config: Config): """Tests Unet output shape.""" x_shape = (2, 16, 16, 3) conditioning_embeddings = { - ConditioningMechanism.ADAPTIVE_NORM: jnp.ones((2, 32)), - ConditioningMechanism.CROSS_ATTENTION: jnp.ones((2, 16, 32)), + 'adaptive_norm': jnp.ones((2, 32)), + 'cross_attention': jnp.ones((2, 16, 32)), } x = jnp.ones(x_shape) model = unet.Unet(**dataclasses.asdict(config), dtype=jnp.float32) @@ -130,8 +130,8 @@ def test_output_num_channels(self, config: Config): num_input_channels = 3 x_shape = (2, 16, 16, num_input_channels) conditioning_embeddings = { - ConditioningMechanism.ADAPTIVE_NORM: jnp.ones((2, 32)), - ConditioningMechanism.CROSS_ATTENTION: jnp.ones((2, 16, 32)), + 'adaptive_norm': jnp.ones((2, 32)), + 'cross_attention': jnp.ones((2, 16, 32)), } x = jnp.ones(x_shape) model = unet.Unet( diff --git a/hackable_diffusion/lib/diffusion_network_test.py b/hackable_diffusion/lib/diffusion_network_test.py index aa7bad6..196e871 100644 --- a/hackable_diffusion/lib/diffusion_network_test.py +++ b/hackable_diffusion/lib/diffusion_network_test.py @@ -14,8 +14,6 @@ """Tests for diffusion_network and its components.""" -from collections.abc import Mapping - import chex from flax import linen as nn from hackable_diffusion.lib import diffusion_network @@ -131,9 +129,9 @@ def setUp(self): }, embedding_merging_method=arch_typing.EmbeddingMergeMethod.CONCAT, conditioning_rules={ - 'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM, - 'label_foo': arch_typing.ConditioningMechanism.ADAPTIVE_NORM, - 'label_bar': arch_typing.ConditioningMechanism.CROSS_ATTENTION, + 'time': 'adaptive_norm', + 'label_foo': 'adaptive_norm', + 'label_bar': 'cross_attention', }, ) self.backbone = unet.Unet(**UNET_CONFIG) @@ -273,9 +271,7 @@ class SelfConditioningBackbone(nn.Module, arch_typing.ConditionalBackbone): def __call__( self, x: arch_typing.DataTree, - conditioning_embeddings: Mapping[ - arch_typing.ConditioningMechanism, Float['batch ...'] - ], + conditioning_embeddings: arch_typing.ConditioningEmbeddings, is_training: bool, ) -> arch_typing.DataTree: return nn.Dense(features=self.num_classes)(x) @@ -334,8 +330,8 @@ def get_schedule_info(self, time): }, embedding_merging_method=arch_typing.EmbeddingMergeMethod.CONCAT, conditioning_rules={ - 'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM, - 'label': arch_typing.ConditioningMechanism.ADAPTIVE_NORM, + 'time': 'adaptive_norm', + 'label': 'adaptive_norm', }, ) self.backbone = SelfConditioningBackbone( @@ -619,9 +615,9 @@ def test_multimodal_diffusion_network(self, input_type: str): }, embedding_merging_method=arch_typing.EmbeddingMergeMethod.CONCAT, conditioning_rules={ - 'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM, - 'label_foo': 
arch_typing.ConditioningMechanism.ADAPTIVE_NORM, - 'label_bar': arch_typing.ConditioningMechanism.CROSS_ATTENTION, + 'time': 'adaptive_norm', + 'label_foo': 'adaptive_norm', + 'label_bar': 'cross_attention', }, ) @@ -702,9 +698,9 @@ def setUp(self): }, embedding_merging_method=arch_typing.EmbeddingMergeMethod.CONCAT, conditioning_rules={ - 'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM, - 'label_foo': arch_typing.ConditioningMechanism.ADAPTIVE_NORM, - 'label_bar': arch_typing.ConditioningMechanism.CROSS_ATTENTION, + 'time': 'adaptive_norm', + 'label_foo': 'adaptive_norm', + 'label_bar': 'cross_attention', }, ) diff --git a/hackable_diffusion/lib/inference/diffusion_inference_test.py b/hackable_diffusion/lib/inference/diffusion_inference_test.py index dc703fe..605d16b 100644 --- a/hackable_diffusion/lib/inference/diffusion_inference_test.py +++ b/hackable_diffusion/lib/inference/diffusion_inference_test.py @@ -62,9 +62,9 @@ ), } CONDITIONING_RULES = { - 'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM, - 'label_foo': arch_typing.ConditioningMechanism.ADAPTIVE_NORM, - 'label_bar': arch_typing.ConditioningMechanism.CROSS_ATTENTION, + 'time': 'adaptive_norm', + 'label_foo': 'adaptive_norm', + 'label_bar': 'cross_attention', } UNET_CONFIG = { diff --git a/hackable_diffusion/lib/multimodal_test.py b/hackable_diffusion/lib/multimodal_test.py index 6d5eefa..902679f 100644 --- a/hackable_diffusion/lib/multimodal_test.py +++ b/hackable_diffusion/lib/multimodal_test.py @@ -59,9 +59,7 @@ class IdentityBackbone(nn.Module, arch_typing.ConditionalBackbone): def __call__( self, x: arch_typing.DataTree, - conditioning_embeddings: dict[ - arch_typing.ConditioningMechanism, Float['batch ...'] - ], + conditioning_embeddings: arch_typing.ConditioningEmbeddings, is_training: bool, ) -> arch_typing.DataTree: return x diff --git a/hackable_diffusion/lib/test_helpers.py b/hackable_diffusion/lib/test_helpers.py index fd2b307..2c834c4 100644 --- a/hackable_diffusion/lib/test_helpers.py +++ b/hackable_diffusion/lib/test_helpers.py @@ -91,9 +91,7 @@ class IdentityBackbone(nn.Module, arch_typing.ConditionalBackbone): def __call__( self, x: arch_typing.DataTree, - conditioning_embeddings: dict[ - arch_typing.ConditioningMechanism, Float['batch ...'] - ], + conditioning_embeddings: arch_typing.ConditioningEmbeddings, is_training: bool, ) -> arch_typing.DataTree: return x diff --git a/hackable_diffusion/notebooks/2d_training.ipynb b/hackable_diffusion/notebooks/2d_training.ipynb index 18e7faa..a6229df 100644 --- a/hackable_diffusion/notebooks/2d_training.ipynb +++ b/hackable_diffusion/notebooks/2d_training.ipynb @@ -187,7 +187,7 @@ " zero_init_output=False,\n", " dtype=jnp.float32,\n", " dropout_rate=0.0,\n", - " conditioning_mechanism=arch_typing.ConditioningMechanism.CONCATENATE,\n", + " conditioning_mechanism='concatenate',\n", ")\n", "\n", "conditioning_embedders = {}\n", @@ -199,7 +199,7 @@ " conditioning_embedders=conditioning_embedders,\n", " embedding_merging_method='concat',\n", " conditioning_rules={\n", - " 'time': arch_typing.ConditioningMechanism.CONCATENATE,\n", + " 'time': 'concatenate',\n", " },\n", ")\n", "\n", diff --git a/hackable_diffusion/notebooks/mnist.ipynb b/hackable_diffusion/notebooks/mnist.ipynb index 69cafe8..9047f5f 100644 --- a/hackable_diffusion/notebooks/mnist.ipynb +++ b/hackable_diffusion/notebooks/mnist.ipynb @@ -289,8 +289,8 @@ " conditioning_embedders=conditioning_embedders,\n", " embedding_merging_method=arch_typing.EmbeddingMergeMethod.SUM,\n", " 
conditioning_rules={\n", - " 'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,\n", - " 'label': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,\n", + " 'time': 'adaptive_norm',\n", + " 'label': 'adaptive_norm',\n", " },\n", ")" ], diff --git a/hackable_diffusion/notebooks/mnist_discrete.ipynb b/hackable_diffusion/notebooks/mnist_discrete.ipynb index eea5b40..c93c870 100644 --- a/hackable_diffusion/notebooks/mnist_discrete.ipynb +++ b/hackable_diffusion/notebooks/mnist_discrete.ipynb @@ -340,8 +340,8 @@ " conditioning_embedders=conditioning_embedders,\n", " embedding_merging_method=arch_typing.EmbeddingMergeMethod.SUM,\n", " conditioning_rules={\n", - " 'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,\n", - " 'label': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,\n", + " 'time': 'adaptive_norm',\n", + " 'label': 'adaptive_norm',\n", " },\n", ")" ], diff --git a/hackable_diffusion/notebooks/mnist_dit.ipynb b/hackable_diffusion/notebooks/mnist_dit.ipynb index 02531bc..6e88f88 100644 --- a/hackable_diffusion/notebooks/mnist_dit.ipynb +++ b/hackable_diffusion/notebooks/mnist_dit.ipynb @@ -286,8 +286,8 @@ " conditioning_embedders=conditioning_embedders,\n", " embedding_merging_method=arch_typing.EmbeddingMergeMethod.CONCAT,\n", " conditioning_rules={\n", - " 'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,\n", - " 'label': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,\n", + " 'time': 'adaptive_norm',\n", + " 'label': 'adaptive_norm',\n", " },\n", ")" ], diff --git a/hackable_diffusion/notebooks/mnist_multimodal.ipynb b/hackable_diffusion/notebooks/mnist_multimodal.ipynb index 9f99c33..c6b8508 100644 --- a/hackable_diffusion/notebooks/mnist_multimodal.ipynb +++ b/hackable_diffusion/notebooks/mnist_multimodal.ipynb @@ -576,7 +576,7 @@ "# Conditional diffusion.\n", "################################################################################\n", "\n", - "ConditioningMechanism = arch_typing.ConditioningMechanism\n", + "\n", "\n", "conditioning_embedders = {\n", " 'label': conditioning_encoder.LabelEmbedder(\n", @@ -604,8 +604,8 @@ " conditioning_embedders=conditioning_embedders,\n", " embedding_merging_method=arch_typing.EmbeddingMergeMethod.SUM,\n", " conditioning_rules={\n", - " 'time': ConditioningMechanism.ADAPTIVE_NORM,\n", - " 'label': ConditioningMechanism.ADAPTIVE_NORM,\n", + " 'time': 'adaptive_norm',\n", + " 'label': 'adaptive_norm',\n", " },\n", ")" ] diff --git a/hackable_diffusion/notebooks/mnist_nn_and_nnx.ipynb b/hackable_diffusion/notebooks/mnist_nn_and_nnx.ipynb index 37c8713..ebc3114 100644 --- a/hackable_diffusion/notebooks/mnist_nn_and_nnx.ipynb +++ b/hackable_diffusion/notebooks/mnist_nn_and_nnx.ipynb @@ -292,8 +292,8 @@ " conditioning_embedders=conditioning_embedders,\n", " embedding_merging_method=arch_typing.EmbeddingMergeMethod.SUM,\n", " conditioning_rules={\n", - " 'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,\n", - " 'label': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,\n", + " 'time': 'adaptive_norm',\n", + " 'label': 'adaptive_norm',\n", " },\n", ")" ], diff --git a/hackable_diffusion/notebooks/mnist_simplicial.ipynb b/hackable_diffusion/notebooks/mnist_simplicial.ipynb index be7315d..1116e1d 100644 --- a/hackable_diffusion/notebooks/mnist_simplicial.ipynb +++ b/hackable_diffusion/notebooks/mnist_simplicial.ipynb @@ -318,8 +318,8 @@ " conditioning_embedders=conditioning_embedders,\n", " embedding_merging_method=arch_typing.EmbeddingMergeMethod.SUM,\n", " conditioning_rules={\n", - " 'time': 
arch_typing.ConditioningMechanism.ADAPTIVE_NORM,\n", - " 'label': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,\n", + " 'time': 'adaptive_norm',\n", + " 'label': 'adaptive_norm',\n", " },\n", ")" ],
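
Net effect of this change: conditioning embeddings and conditioning rules are keyed by plain strings ('adaptive_norm', 'cross_attention', 'concatenate', 'sum', 'self_conditioning') instead of `ConditioningMechanism` enum members. The sketch below is illustrative only and is not part of this diff: `ToyBackbone` and its sizes are hypothetical; only the string-keyed `ConditioningEmbeddings` convention comes from the changes above.

```python
# Illustrative sketch (not part of this diff): a toy conditional backbone that
# consumes a string-keyed ConditioningEmbeddings dict. `ToyBackbone` is a
# hypothetical module; only the dict convention ('adaptive_norm', ...) is
# taken from the diff.
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp

ConditioningEmbeddings = dict[str, Any]  # mirrors arch_typing.ConditioningEmbeddings


class ToyBackbone(nn.Module):
  """Minimal backbone that reads conditioning embeddings by string key."""

  features: int = 32

  @nn.compact
  def __call__(
      self,
      x: jax.Array,
      conditioning_embeddings: ConditioningEmbeddings,
      is_training: bool,
  ) -> jax.Array:
    h = nn.Dense(self.features)(x)
    # Adaptive-norm style conditioning: derive a per-sample scale and shift
    # from the 'adaptive_norm' embedding, if one was provided.
    adanorm = conditioning_embeddings.get('adaptive_norm')
    if adanorm is not None:
      scale_shift = nn.Dense(2 * self.features)(adanorm)
      scale, shift = jnp.split(scale_shift, 2, axis=-1)
      h = h * (1.0 + scale) + shift
    return nn.Dense(x.shape[-1])(h)


# Keys of the embeddings dict are plain strings, matching the values used in
# `conditioning_rules` throughout the diff.
model = ToyBackbone()
x = jnp.ones((2, 8))
embeddings = {'adaptive_norm': jnp.ones((2, 16))}
params = model.init(jax.random.PRNGKey(0), x, embeddings, is_training=False)
out = model.apply(params, x, embeddings, is_training=False)
print(out.shape)  # (2, 8)
```

Because `ConditioningEmbeddings` is just `dict[str, Any]`, a backbone looks mechanisms up with `dict.get`, and adding a custom mechanism no longer requires touching an enum.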