Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions hackable_diffusion/lib/architecture/arch_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""

import enum
from typing import Any
from typing import Callable, Protocol
from hackable_diffusion.lib import hd_typing
import jax
Expand Down Expand Up @@ -50,17 +51,6 @@ class EmbeddingMergeMethod(enum.StrEnum):
CONCAT = "concat"


class ConditioningMechanism(enum.StrEnum):
"""Types of conditioning mechanisms."""

ADAPTIVE_NORM = "adaptive_norm"
CROSS_ATTENTION = "cross_attention"
CONCATENATE = "concatenate"
SUM = "sum"
SELF_CONDITIONING = "self_conditioning"
CUSTOM = "custom"


class RoPEPositionType(enum.StrEnum):
"""Rotary Position Embedding (RoPE) types."""

Expand Down Expand Up @@ -95,6 +85,29 @@ class SkipConnectionMethod(enum.StrEnum):
NORMALIZED_ADD = "normalized_add"


################################################################################
# MARK: Conditioning Mechanism
################################################################################


class ConditioningMechanism(enum.StrEnum):
  """Types of conditioning mechanisms.

  As a `StrEnum`, each member compares equal to its lowercase string value,
  so members can be used directly as string keys of a
  `ConditioningEmbeddings` dictionary (e.g. a backbone may look up the
  "adaptive_norm" entry). Arbitrary other string keys remain permitted for
  mechanisms not listed here.
  """

  ADAPTIVE_NORM = "adaptive_norm"
  CROSS_ATTENTION = "cross_attention"
  CONCATENATE = "concatenate"
  SUM = "sum"
  SELF_CONDITIONING = "self_conditioning"
  CUSTOM = "custom"


# Conditioning embeddings correspond to a dictionary whose keys specify a
# conditioning mechanism. We use `ConditioningMechanism` as a reference for
# the most common conditioning mechanisms, but the type structure is kept
# more flexible (plain `str` keys, arbitrary values) to allow for more
# general conditioning mechanisms.
ConditioningEmbeddings = dict[str, Any]


################################################################################
# MARK: Types and protocols
################################################################################
Expand All @@ -108,7 +121,7 @@ class ConditionalBackbone(Protocol):
def __call__(
self,
x: DataTree,
conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]],
conditioning_embeddings: ConditioningEmbeddings,
is_training: bool,
) -> DataTree:
...
Expand Down
29 changes: 26 additions & 3 deletions hackable_diffusion/lib/architecture/conditioning_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

ConditioningMechanism = arch_typing.ConditioningMechanism
EmbeddingMergeMethod = arch_typing.EmbeddingMergeMethod
ConditioningEmbeddings = arch_typing.ConditioningEmbeddings

################################################################################
# MARK: Base classes
Expand Down Expand Up @@ -76,7 +77,7 @@ def __call__(
time: hd_typing.TimeArray,
conditioning: hd_typing.Conditioning | None,
is_training: bool,
) -> dict[ConditioningMechanism, Float['batch ...']]:
) -> ConditioningEmbeddings:
...


Expand Down Expand Up @@ -342,6 +343,28 @@ def __call__(
################################################################################


class CopyConditioningEncoder(nn.Module, BaseConditioningEncoder):
  """Passes the conditioning through unchanged, optionally embedding time.

  If a `time_embedder` is configured, its output is stored under the
  'time' key; every entry of the raw `conditioning` dictionary is then
  copied into the result as-is.
  """

  time_embedder: BaseTimeEmbedder | None = None

  @nn.compact
  def __call__(
      self,
      time: hd_typing.TimeArray,
      conditioning: hd_typing.Conditioning | None,
      is_training: bool,
  ) -> ConditioningEmbeddings:
    outputs = {}
    embedder = self.time_embedder
    if embedder is not None:
      # `copy(name=...)` pins the submodule name so parameters live under
      # 'TimeEmbedder' regardless of the field's attribute name.
      time_module = cast(nn.Module, embedder).copy(name='TimeEmbedder')
      outputs['time'] = time_module(time)
    if conditioning is not None:
      outputs.update(conditioning)
    return outputs


class ConditioningEncoder(nn.Module, BaseConditioningEncoder):
"""Encodes and combines time and conditioning signals for a diffusion model.

Expand Down Expand Up @@ -392,7 +415,7 @@ class ConditioningEncoder(nn.Module, BaseConditioningEncoder):
time_embedder: BaseTimeEmbedder
conditioning_embedders: dict[str, BaseEmbedder]
embedding_merging_method: EmbeddingMergeMethod
conditioning_rules: dict[str, ConditioningMechanism]
conditioning_rules: dict[str, str]
conditioning_dropout_rate: float = 0.0

def setup(self):
Expand Down Expand Up @@ -427,7 +450,7 @@ def __call__(
time: hd_typing.TimeTree,
conditioning: hd_typing.Conditioning | None,
is_training: bool,
) -> dict[ConditioningMechanism, Num['batch ...']]:
) -> ConditioningEmbeddings:
"""Encodes and combines time and conditioning signals.

The output is a dictionary where keys are the embedding mechanisms specified
Expand Down
74 changes: 74 additions & 0 deletions hackable_diffusion/lib/architecture/conditioning_encoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,80 @@ def test_dropout(self):
jnp.all(output_eval['adaptive_norm'] == time_embedding_train)
)

def test_copy_conditioning_encoder(self):
  """Tests that CopyConditioningEncoder copies conditioning unchanged."""
  batch = 4
  encoder = conditioning_encoder.CopyConditioningEncoder()
  times = jnp.ones((batch,))
  cond = {
      'label': jnp.arange(batch),
      'arbitrary_pytree': dict(
          a=jnp.ones((batch, 127, 199)),
          b=dict(c=jnp.ones((batch, 127, 199))),
      ),
  }
  key = jax.random.PRNGKey(0)
  params = encoder.init(key, times, cond, is_training=True)
  apply_fn = jax.jit(encoder.apply, static_argnames=['is_training'])
  result = apply_fn(
      params,
      times,
      cond,
      is_training=True,
      rngs={'dropout': key},
  )
  # The output must share the conditioning's pytree structure, and every
  # leaf must be copied through unchanged.
  result_leaves, result_structure = jax.tree.flatten(result)
  expected_leaves, expected_structure = jax.tree.flatten(cond)
  self.assertEqual(result_structure, expected_structure)
  for got, expected in zip(result_leaves, expected_leaves):
    self.assertTrue(jnp.array_equal(got, expected))

def test_copy_conditioning_encoder_with_time_embedder(self):
  """Tests CopyConditioningEncoder with a time embedder attached.

  The conditioning dictionary must be copied through unchanged, and the
  time embedding must appear under the 'time' key with the configured
  feature dimension.
  """
  batch_size = 4
  num_features = 17
  time_embedder = conditioning_encoder.SinusoidalTimeEmbedder(
      activation='silu',
      embedding_dim=5,
      num_features=num_features,
  )

  encoder = conditioning_encoder.CopyConditioningEncoder(
      time_embedder=time_embedder
  )
  t = jnp.ones((batch_size,))
  conditioning = {
      'label': jnp.arange(batch_size),
      'arbitrary_pytree': dict(
          a=jnp.ones((batch_size, 127, 199)),
          b=dict(c=jnp.ones((batch_size, 127, 199))),
      ),
  }
  rng = jax.random.PRNGKey(0)
  variables = encoder.init(rng, t, conditioning, is_training=True)
  jitted_apply = jax.jit(encoder.apply, static_argnames=['is_training'])
  output = jitted_apply(
      variables,
      t,
      conditioning,
      is_training=True,
      rngs={'dropout': rng},
  )

  # Verify all conditioning fields are copied through: same pytree
  # structure per key, with every leaf equal.
  for key in conditioning:
    self.assertIn(key, output)
    out_leaves, out_td = jax.tree.flatten(output[key])
    cond_leaves, cond_td = jax.tree.flatten(conditioning[key])
    self.assertEqual(out_td, cond_td)
    for out_leaf, cond_leaf in zip(out_leaves, cond_leaves):
      self.assertTrue(jnp.array_equal(out_leaf, cond_leaf))

  # Verify the time embedding shape: (batch, num_features).
  self.assertIn('time', output)
  self.assertEqual(output['time'].shape, (batch_size, num_features))


if __name__ == '__main__':
  # Discover and run all test cases in this module via absltest.
  absltest.main()
3 changes: 2 additions & 1 deletion hackable_diffusion/lib/architecture/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

ConditionalBackbone = arch_typing.ConditionalBackbone
ConditioningMechanism = arch_typing.ConditioningMechanism
ConditioningEmbeddings = arch_typing.ConditioningEmbeddings

################################################################################
# MARK: Token Embedder
Expand Down Expand Up @@ -213,7 +214,7 @@ def __post_init__(self):
def __call__(
self,
x: Int['batch *other 1'],
conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']],
conditioning_embeddings: ConditioningEmbeddings,
is_training: bool,
) -> Float['batch *other V']:

Expand Down
3 changes: 2 additions & 1 deletion hackable_diffusion/lib/architecture/dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

ConditionalBackbone = arch_typing.ConditionalBackbone
ConditioningMechanism = arch_typing.ConditioningMechanism
ConditioningEmbeddings = arch_typing.ConditioningEmbeddings
NormalizationType = arch_typing.NormalizationType

################################################################################
Expand Down Expand Up @@ -100,7 +101,7 @@ def setup(self):
def __call__(
self,
x: DataArray,
conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]],
conditioning_embeddings: ConditioningEmbeddings,
is_training: bool,
) -> DataArray:
adaptive_norm_emb = conditioning_embeddings.get(
Expand Down
3 changes: 2 additions & 1 deletion hackable_diffusion/lib/architecture/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

ConditionalBackbone = arch_typing.ConditionalBackbone
ConditioningMechanism = arch_typing.ConditioningMechanism
ConditioningEmbeddings = arch_typing.ConditioningEmbeddings

################################################################################
# MARK: ConditionalMLP
Expand Down Expand Up @@ -77,7 +78,7 @@ class ConditionalMLP(nn.Module, ConditionalBackbone):
def __call__(
self,
x: DataArray,
conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']],
conditioning_embeddings: ConditioningEmbeddings,
*,
is_training: bool,
) -> DataArray:
Expand Down
13 changes: 10 additions & 3 deletions hackable_diffusion/lib/architecture/riemannian.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
################################################################################

ConditionalBackbone = arch_typing.ConditionalBackbone
ConditioningEmbeddings = arch_typing.ConditioningEmbeddings


class RiemannianConditionalBackbone(nn.Module, ConditionalBackbone):
Expand All @@ -35,9 +36,15 @@ class RiemannianConditionalBackbone(nn.Module, ConditionalBackbone):
manifold: manifolds.Manifold

@nn.compact
def __call__(self, x, conditioning_embeddings, is_training=True):

v = self.backbone(x, conditioning_embeddings, is_training=is_training)
def __call__(
self, x, conditioning_embeddings: ConditioningEmbeddings, is_training=True
):

v = self.backbone(
x,
conditioning_embeddings=conditioning_embeddings,
is_training=is_training,
)

# Project v to tangent space at xt.
if isinstance(v, dict) and 'velocity' in v:
Expand Down
3 changes: 2 additions & 1 deletion hackable_diffusion/lib/architecture/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
SkipConnectionMethod = arch_typing.SkipConnectionMethod
ConditionalBackbone = arch_typing.ConditionalBackbone
ConditioningMechanism = arch_typing.ConditioningMechanism
ConditioningEmbeddings = arch_typing.ConditioningEmbeddings

################################################################################
# MARK: Unet
Expand Down Expand Up @@ -163,7 +164,7 @@ def setup(self):
def __call__(
self,
x: Float["batch height width channels"],
conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]],
conditioning_embeddings: ConditioningEmbeddings,
*,
is_training: bool,
) -> Float["batch height width output_channels"]:
Expand Down
4 changes: 1 addition & 3 deletions hackable_diffusion/lib/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ class IdentityBackbone(nn.Module, arch_typing.ConditionalBackbone):
def __call__(
    self,
    x: arch_typing.DataTree,
    conditioning_embeddings: arch_typing.ConditioningEmbeddings,
    is_training: bool,
) -> arch_typing.DataTree:
  """Returns `x` unchanged, ignoring conditioning and the training flag.

  Note: the pasted diff span duplicated the `conditioning_embeddings`
  parameter (old and new annotations interleaved); only the final
  `ConditioningEmbeddings` annotation is kept here.
  """
  del conditioning_embeddings, is_training  # Unused: identity backbone.
  return x
Loading