diff --git a/hackable_diffusion/lib/architecture/arch_typing.py b/hackable_diffusion/lib/architecture/arch_typing.py index 0a8cd8d..5a8264d 100644 --- a/hackable_diffusion/lib/architecture/arch_typing.py +++ b/hackable_diffusion/lib/architecture/arch_typing.py @@ -19,6 +19,7 @@ """ import enum +from typing import Any from typing import Callable, Protocol from hackable_diffusion.lib import hd_typing import jax @@ -50,17 +51,6 @@ class EmbeddingMergeMethod(enum.StrEnum): CONCAT = "concat" -class ConditioningMechanism(enum.StrEnum): - """Types of conditioning mechanisms.""" - - ADAPTIVE_NORM = "adaptive_norm" - CROSS_ATTENTION = "cross_attention" - CONCATENATE = "concatenate" - SUM = "sum" - SELF_CONDITIONING = "self_conditioning" - CUSTOM = "custom" - - class RoPEPositionType(enum.StrEnum): """Rotary Position Embedding (RoPE) types.""" @@ -95,6 +85,29 @@ class SkipConnectionMethod(enum.StrEnum): NORMALIZED_ADD = "normalized_add" +################################################################################ +# MARK: Conditioning Mechanism +################################################################################ + + +class ConditioningMechanism(enum.StrEnum): + """Types of conditioning mechanisms.""" + + ADAPTIVE_NORM = "adaptive_norm" + CROSS_ATTENTION = "cross_attention" + CONCATENATE = "concatenate" + SUM = "sum" + SELF_CONDITIONING = "self_conditioning" + CUSTOM = "custom" + + +# Conditioning embeddings correspond to a dictionary with keys corresponding to +# the specification of a conditioning mechanism. We use `ConditioningMechanism` +# as a reference for the most common conditioning mechanisms, but the type +# structure is more flexible to allow for more general conditioning mechanisms. 
+ConditioningEmbeddings = dict[str, Any] + + ################################################################################ # MARK: Types and protocols ################################################################################ @@ -108,7 +121,7 @@ class ConditionalBackbone(Protocol): def __call__( self, x: DataTree, - conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]], + conditioning_embeddings: ConditioningEmbeddings, is_training: bool, ) -> DataTree: ... diff --git a/hackable_diffusion/lib/architecture/conditioning_encoder.py b/hackable_diffusion/lib/architecture/conditioning_encoder.py index dc2d973..adfd284 100644 --- a/hackable_diffusion/lib/architecture/conditioning_encoder.py +++ b/hackable_diffusion/lib/architecture/conditioning_encoder.py @@ -42,6 +42,7 @@ ConditioningMechanism = arch_typing.ConditioningMechanism EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings ################################################################################ # MARK: Base classes @@ -76,7 +77,7 @@ def __call__( time: hd_typing.TimeArray, conditioning: hd_typing.Conditioning | None, is_training: bool, - ) -> dict[ConditioningMechanism, Float['batch ...']]: + ) -> ConditioningEmbeddings: ... 
@@ -342,6 +343,28 @@ def __call__( ################################################################################ +class CopyConditioningEncoder(nn.Module, BaseConditioningEncoder): + """Copies the conditioning to the output and optionally encodes time.""" + + time_embedder: BaseTimeEmbedder | None = None + + @nn.compact + def __call__( + self, + time: hd_typing.TimeArray, + conditioning: hd_typing.Conditioning | None, + is_training: bool, + ) -> ConditioningEmbeddings: + cond_embs = {} + if self.time_embedder is not None: + cond_embs['time'] = cast(nn.Module, self.time_embedder).copy( + name='TimeEmbedder' + )(time) + if conditioning is not None: + cond_embs.update(conditioning) + return cond_embs + + class ConditioningEncoder(nn.Module, BaseConditioningEncoder): """Encodes and combines time and conditioning signals for a diffusion model. @@ -392,7 +415,7 @@ class ConditioningEncoder(nn.Module, BaseConditioningEncoder): time_embedder: BaseTimeEmbedder conditioning_embedders: dict[str, BaseEmbedder] embedding_merging_method: EmbeddingMergeMethod - conditioning_rules: dict[str, ConditioningMechanism] + conditioning_rules: dict[str, str] conditioning_dropout_rate: float = 0.0 def setup(self): @@ -427,7 +450,7 @@ def __call__( time: hd_typing.TimeTree, conditioning: hd_typing.Conditioning | None, is_training: bool, - ) -> dict[ConditioningMechanism, Num['batch ...']]: + ) -> ConditioningEmbeddings: """Encodes and combines time and conditioning signals. 
The output is a dictionary where keys are the embedding mechanisms specified diff --git a/hackable_diffusion/lib/architecture/conditioning_encoder_test.py b/hackable_diffusion/lib/architecture/conditioning_encoder_test.py index 4b6849b..59d2ad1 100644 --- a/hackable_diffusion/lib/architecture/conditioning_encoder_test.py +++ b/hackable_diffusion/lib/architecture/conditioning_encoder_test.py @@ -699,6 +699,80 @@ def test_dropout(self): jnp.all(output_eval['adaptive_norm'] == time_embedding_train) ) + def test_copy_conditioning_encoder(self): + """Tests the CopyConditioningEncoder.""" + batch_size = 4 + encoder = conditioning_encoder.CopyConditioningEncoder() + t = jnp.ones((batch_size,)) + conditioning = { + 'label': jnp.arange(batch_size), + 'arbitrary_pytree': dict( + a=jnp.ones((batch_size, 127, 199)), + b=dict(c=jnp.ones((batch_size, 127, 199))), + ), + } + rng = jax.random.PRNGKey(0) + variables = encoder.init(rng, t, conditioning, is_training=True) + jitted_apply = jax.jit(encoder.apply, static_argnames=['is_training']) + output = jitted_apply( + variables, + t, + conditioning, + is_training=True, + rngs={'dropout': rng}, + ) + # Compare the pytree structure and leaves individually. 
+ output_leaves, output_treedef = jax.tree.flatten(output) + cond_leaves, cond_treedef = jax.tree.flatten(conditioning) + self.assertEqual(output_treedef, cond_treedef) + for out_leaf, cond_leaf in zip(output_leaves, cond_leaves): + self.assertTrue(jnp.array_equal(out_leaf, cond_leaf)) + + def test_copy_conditioning_encoder_with_time_embedder(self): + """Tests the CopyConditioningEncoder with a time embedder.""" + batch_size = 4 + num_features = 17 + time_embedder = conditioning_encoder.SinusoidalTimeEmbedder( + activation='silu', + embedding_dim=5, + num_features=num_features, + ) + + encoder = conditioning_encoder.CopyConditioningEncoder( + time_embedder=time_embedder + ) + t = jnp.ones((batch_size,)) + conditioning = { + 'label': jnp.arange(batch_size), + 'arbitrary_pytree': dict( + a=jnp.ones((batch_size, 127, 199)), + b=dict(c=jnp.ones((batch_size, 127, 199))), + ), + } + rng = jax.random.PRNGKey(0) + variables = encoder.init(rng, t, conditioning, is_training=True) + jitted_apply = jax.jit(encoder.apply, static_argnames=['is_training']) + output = jitted_apply( + variables, + t, + conditioning, + is_training=True, + rngs={'dropout': rng}, + ) + + # Verify all conditioning fields are copied through. + for key in conditioning: + self.assertIn(key, output) + out_leaves, out_td = jax.tree.flatten(output[key]) + cond_leaves, cond_td = jax.tree.flatten(conditioning[key]) + self.assertEqual(out_td, cond_td) + for out_leaf, cond_leaf in zip(out_leaves, cond_leaves): + self.assertTrue(jnp.array_equal(out_leaf, cond_leaf)) + + # Verify the time embedding shape. 
+ self.assertIn('time', output) + self.assertEqual(output['time'].shape, (batch_size, num_features)) + if __name__ == '__main__': absltest.main() diff --git a/hackable_diffusion/lib/architecture/discrete.py b/hackable_diffusion/lib/architecture/discrete.py index 77d93ac..d8521a3 100644 --- a/hackable_diffusion/lib/architecture/discrete.py +++ b/hackable_diffusion/lib/architecture/discrete.py @@ -33,6 +33,7 @@ ConditionalBackbone = arch_typing.ConditionalBackbone ConditioningMechanism = arch_typing.ConditioningMechanism +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings ################################################################################ # MARK: Token Embedder @@ -213,7 +214,7 @@ def __post_init__(self): def __call__( self, x: Int['batch *other 1'], - conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']], + conditioning_embeddings: ConditioningEmbeddings, is_training: bool, ) -> Float['batch *other V']: diff --git a/hackable_diffusion/lib/architecture/dit.py b/hackable_diffusion/lib/architecture/dit.py index fec0891..7118153 100644 --- a/hackable_diffusion/lib/architecture/dit.py +++ b/hackable_diffusion/lib/architecture/dit.py @@ -40,6 +40,7 @@ ConditionalBackbone = arch_typing.ConditionalBackbone ConditioningMechanism = arch_typing.ConditioningMechanism +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings NormalizationType = arch_typing.NormalizationType ################################################################################ @@ -100,7 +101,7 @@ def setup(self): def __call__( self, x: DataArray, - conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]], + conditioning_embeddings: ConditioningEmbeddings, is_training: bool, ) -> DataArray: adaptive_norm_emb = conditioning_embeddings.get( diff --git a/hackable_diffusion/lib/architecture/mlp.py b/hackable_diffusion/lib/architecture/mlp.py index f2c2ab0..8d7d585 100644 --- a/hackable_diffusion/lib/architecture/mlp.py +++ 
b/hackable_diffusion/lib/architecture/mlp.py @@ -38,6 +38,7 @@ ConditionalBackbone = arch_typing.ConditionalBackbone ConditioningMechanism = arch_typing.ConditioningMechanism +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings ################################################################################ # MARK: ConditionalMLP @@ -77,7 +78,7 @@ class ConditionalMLP(nn.Module, ConditionalBackbone): def __call__( self, x: DataArray, - conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']], + conditioning_embeddings: ConditioningEmbeddings, *, is_training: bool, ) -> DataArray: diff --git a/hackable_diffusion/lib/architecture/riemannian.py b/hackable_diffusion/lib/architecture/riemannian.py index 225e88e..a81677b 100644 --- a/hackable_diffusion/lib/architecture/riemannian.py +++ b/hackable_diffusion/lib/architecture/riemannian.py @@ -23,6 +23,7 @@ ################################################################################ ConditionalBackbone = arch_typing.ConditionalBackbone +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings class RiemannianConditionalBackbone(nn.Module, ConditionalBackbone): @@ -35,9 +36,15 @@ class RiemannianConditionalBackbone(nn.Module, ConditionalBackbone): manifold: manifolds.Manifold @nn.compact - def __call__(self, x, conditioning_embeddings, is_training=True): - - v = self.backbone(x, conditioning_embeddings, is_training=is_training) + def __call__( + self, x, conditioning_embeddings: ConditioningEmbeddings, is_training=True + ): + + v = self.backbone( + x, + conditioning_embeddings=conditioning_embeddings, + is_training=is_training, + ) # Project v to tangent space at xt. 
if isinstance(v, dict) and 'velocity' in v: diff --git a/hackable_diffusion/lib/architecture/unet.py b/hackable_diffusion/lib/architecture/unet.py index d8159c5..cef7426 100644 --- a/hackable_diffusion/lib/architecture/unet.py +++ b/hackable_diffusion/lib/architecture/unet.py @@ -39,6 +39,7 @@ SkipConnectionMethod = arch_typing.SkipConnectionMethod ConditionalBackbone = arch_typing.ConditionalBackbone ConditioningMechanism = arch_typing.ConditioningMechanism +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings ################################################################################ # MARK: Unet @@ -163,7 +164,7 @@ def setup(self): def __call__( self, x: Float["batch height width channels"], - conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]], + conditioning_embeddings: ConditioningEmbeddings, *, is_training: bool, ) -> Float["batch height width output_channels"]: diff --git a/hackable_diffusion/lib/test_helpers.py b/hackable_diffusion/lib/test_helpers.py index fd2b307..2c834c4 100644 --- a/hackable_diffusion/lib/test_helpers.py +++ b/hackable_diffusion/lib/test_helpers.py @@ -91,9 +91,7 @@ class IdentityBackbone(nn.Module, arch_typing.ConditionalBackbone): def __call__( self, x: arch_typing.DataTree, - conditioning_embeddings: dict[ - arch_typing.ConditioningMechanism, Float['batch ...'] - ], + conditioning_embeddings: arch_typing.ConditioningEmbeddings, is_training: bool, ) -> arch_typing.DataTree: return x