diff --git a/hackable_diffusion/lib/architecture/arch_typing.py b/hackable_diffusion/lib/architecture/arch_typing.py index 0a8cd8d..5a8264d 100644 --- a/hackable_diffusion/lib/architecture/arch_typing.py +++ b/hackable_diffusion/lib/architecture/arch_typing.py @@ -19,6 +19,7 @@ """ import enum +from typing import Any from typing import Callable, Protocol from hackable_diffusion.lib import hd_typing import jax @@ -50,17 +51,6 @@ class EmbeddingMergeMethod(enum.StrEnum): CONCAT = "concat" -class ConditioningMechanism(enum.StrEnum): - """Types of conditioning mechanisms.""" - - ADAPTIVE_NORM = "adaptive_norm" - CROSS_ATTENTION = "cross_attention" - CONCATENATE = "concatenate" - SUM = "sum" - SELF_CONDITIONING = "self_conditioning" - CUSTOM = "custom" - - class RoPEPositionType(enum.StrEnum): """Rotary Position Embedding (RoPE) types.""" @@ -95,6 +85,29 @@ class SkipConnectionMethod(enum.StrEnum): NORMALIZED_ADD = "normalized_add" +################################################################################ +# MARK: Conditioning Mechanism +################################################################################ + + +class ConditioningMechanism(enum.StrEnum): + """Types of conditioning mechanisms.""" + + ADAPTIVE_NORM = "adaptive_norm" + CROSS_ATTENTION = "cross_attention" + CONCATENATE = "concatenate" + SUM = "sum" + SELF_CONDITIONING = "self_conditioning" + CUSTOM = "custom" + + +# Conditioning embeddings corresponds to a dictionary with keys corresponding to +# the specification of a conditioning mechanism. We use `ConditioningMechanism` +# as a reference for the most common conditioning mechanisms, but the type +# structure is more flexible to allow for more general conditioning mechanisms. +ConditioningEmbeddings = dict[str, Any] + + ################################################################################ # MARK: Types and protocols ################################################################################ @@ -108,7 +121,7 @@ class ConditionalBackbone(Protocol): def __call__( self, x: DataTree, - conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]], + conditioning_embeddings: ConditioningEmbeddings, is_training: bool, ) -> DataTree: ... diff --git a/hackable_diffusion/lib/architecture/conditioning_encoder.py b/hackable_diffusion/lib/architecture/conditioning_encoder.py index dc2d973..d613525 100644 --- a/hackable_diffusion/lib/architecture/conditioning_encoder.py +++ b/hackable_diffusion/lib/architecture/conditioning_encoder.py @@ -42,6 +42,7 @@ ConditioningMechanism = arch_typing.ConditioningMechanism EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings ################################################################################ # MARK: Base classes @@ -76,7 +77,7 @@ def __call__( time: hd_typing.TimeArray, conditioning: hd_typing.Conditioning | None, is_training: bool, - ) -> dict[ConditioningMechanism, Float['batch ...']]: + ) -> ConditioningEmbeddings: ... @@ -392,7 +393,7 @@ class ConditioningEncoder(nn.Module, BaseConditioningEncoder): time_embedder: BaseTimeEmbedder conditioning_embedders: dict[str, BaseEmbedder] embedding_merging_method: EmbeddingMergeMethod - conditioning_rules: dict[str, ConditioningMechanism] + conditioning_rules: dict[str, str] conditioning_dropout_rate: float = 0.0 def setup(self): @@ -427,7 +428,7 @@ def __call__( time: hd_typing.TimeTree, conditioning: hd_typing.Conditioning | None, is_training: bool, - ) -> dict[ConditioningMechanism, Num['batch ...']]: + ) -> ConditioningEmbeddings: """Encodes and combines time and conditioning signals. The output is a dictionary where keys are the embedding mechanisms specified diff --git a/hackable_diffusion/lib/architecture/discrete.py b/hackable_diffusion/lib/architecture/discrete.py index 77d93ac..d8521a3 100644 --- a/hackable_diffusion/lib/architecture/discrete.py +++ b/hackable_diffusion/lib/architecture/discrete.py @@ -33,6 +33,7 @@ ConditionalBackbone = arch_typing.ConditionalBackbone ConditioningMechanism = arch_typing.ConditioningMechanism +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings ################################################################################ # MARK: Token Embedder @@ -213,7 +214,7 @@ def __post_init__(self): def __call__( self, x: Int['batch *other 1'], - conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']], + conditioning_embeddings: ConditioningEmbeddings, is_training: bool, ) -> Float['batch *other V']: diff --git a/hackable_diffusion/lib/architecture/dit.py b/hackable_diffusion/lib/architecture/dit.py index fec0891..7118153 100644 --- a/hackable_diffusion/lib/architecture/dit.py +++ b/hackable_diffusion/lib/architecture/dit.py @@ -40,6 +40,7 @@ ConditionalBackbone = arch_typing.ConditionalBackbone ConditioningMechanism = arch_typing.ConditioningMechanism +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings NormalizationType = arch_typing.NormalizationType ################################################################################ @@ -100,7 +101,7 @@ def setup(self): def __call__( self, x: DataArray, - conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]], + conditioning_embeddings: ConditioningEmbeddings, is_training: bool, ) -> DataArray: adaptive_norm_emb = conditioning_embeddings.get( diff --git a/hackable_diffusion/lib/architecture/mlp.py b/hackable_diffusion/lib/architecture/mlp.py index f2c2ab0..8d7d585 100644 --- a/hackable_diffusion/lib/architecture/mlp.py +++ b/hackable_diffusion/lib/architecture/mlp.py @@ -38,6 +38,7 @@ ConditionalBackbone = arch_typing.ConditionalBackbone ConditioningMechanism = arch_typing.ConditioningMechanism +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings ################################################################################ # MARK: ConditionalMLP @@ -77,7 +78,7 @@ class ConditionalMLP(nn.Module, ConditionalBackbone): def __call__( self, x: DataArray, - conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']], + conditioning_embeddings: ConditioningEmbeddings, *, is_training: bool, ) -> DataArray: diff --git a/hackable_diffusion/lib/architecture/riemannian.py b/hackable_diffusion/lib/architecture/riemannian.py index 225e88e..a81677b 100644 --- a/hackable_diffusion/lib/architecture/riemannian.py +++ b/hackable_diffusion/lib/architecture/riemannian.py @@ -23,6 +23,7 @@ ################################################################################ ConditionalBackbone = arch_typing.ConditionalBackbone +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings class RiemannianConditionalBackbone(nn.Module, ConditionalBackbone): @@ -35,9 +36,15 @@ class RiemannianConditionalBackbone(nn.Module, ConditionalBackbone): manifold: manifolds.Manifold @nn.compact - def __call__(self, x, conditioning_embeddings, is_training=True): - - v = self.backbone(x, conditioning_embeddings, is_training=is_training) + def __call__( + self, x, conditioning_embeddings: ConditioningEmbeddings, is_training=True + ): + + v = self.backbone( + x, + conditioning_embeddings=conditioning_embeddings, + is_training=is_training, + ) # Project v to tangent space at xt. if isinstance(v, dict) and 'velocity' in v: diff --git a/hackable_diffusion/lib/architecture/unet.py b/hackable_diffusion/lib/architecture/unet.py index d8159c5..cef7426 100644 --- a/hackable_diffusion/lib/architecture/unet.py +++ b/hackable_diffusion/lib/architecture/unet.py @@ -39,6 +39,7 @@ SkipConnectionMethod = arch_typing.SkipConnectionMethod ConditionalBackbone = arch_typing.ConditionalBackbone ConditioningMechanism = arch_typing.ConditioningMechanism +ConditioningEmbeddings = arch_typing.ConditioningEmbeddings ################################################################################ # MARK: Unet @@ -163,7 +164,7 @@ def setup(self): def __call__( self, x: Float["batch height width channels"], - conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]], + conditioning_embeddings: ConditioningEmbeddings, *, is_training: bool, ) -> Float["batch height width output_channels"]: diff --git a/hackable_diffusion/lib/test_helpers.py b/hackable_diffusion/lib/test_helpers.py index fd2b307..2c834c4 100644 --- a/hackable_diffusion/lib/test_helpers.py +++ b/hackable_diffusion/lib/test_helpers.py @@ -91,9 +91,7 @@ class IdentityBackbone(nn.Module, arch_typing.ConditionalBackbone): def __call__( self, x: arch_typing.DataTree, - conditioning_embeddings: dict[ - arch_typing.ConditioningMechanism, Float['batch ...'] - ], + conditioning_embeddings: arch_typing.ConditioningEmbeddings, is_training: bool, ) -> arch_typing.DataTree: return x