164 changes: 164 additions & 0 deletions hackable_diffusion/lib/multimodal.py
@@ -53,8 +53,10 @@
from typing import cast

import flax.linen as nn
from hackable_diffusion.lib import diffusion_network
from hackable_diffusion.lib import hd_typing
from hackable_diffusion.lib import utils
from hackable_diffusion.lib.architecture import arch_typing
from hackable_diffusion.lib.architecture import conditioning_encoder
from hackable_diffusion.lib.corruption import base as corruption_base
from hackable_diffusion.lib.inference import guidance as guidance_lib
@@ -71,6 +73,7 @@
# MARK: Type Aliases
################################################################################

DType = hd_typing.DType
PRNGKey = hd_typing.PRNGKey
PyTree = hd_typing.PyTree

@@ -563,3 +566,164 @@ def _call_sampler(sampler, data_spec):
return sampler(key, data_spec)

return jax.tree.map(_call_sampler, self.samplers, data_spec)


################################################################################
# MARK: NestedSelfConditioningDiffusionNetwork
################################################################################


class NestedSelfConditioningDiffusionNetwork(
nn.Module, diffusion_network.BaseDiffusionNetwork
):
"""Multi-modal DiffusionNetwork with self-conditioning on predicted logits.

This class generalizes `SelfConditioningDiffusionNetwork` to PyTree data.
It assumes ALL modalities in the PyTree are discrete and require
self-conditioning.

Attributes:
backbone_network: The backbone network to use. Must accept PyTree inputs
where each leaf has the self-conditioning logits concatenated along its
last axis.
conditioning_encoder: The conditioning encoder to use.
prediction_type: PyTree of prediction-type strings; every leaf must be
'logits'.
processes: A NestedProcess containing the corruption processes for each
modality (used to get `num_categories`).
self_cond_prob: Probability of applying self-conditioning during training.
data_dtype: PyTree of dtypes for each modality.
input_rescaler: Optional PyTree of input rescalers.
time_rescaler: Optional PyTree of time rescalers.
rng_collection: PRNG collection name for the self-conditioning mask.
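
Example (illustrative; `backbone`, `encoder`, and `processes` stand in for
real instances of the attribute types above):

  net = NestedSelfConditioningDiffusionNetwork(
      backbone_network=backbone,
      conditioning_encoder=encoder,
      prediction_type={'text': 'logits', 'image': 'logits'},
      processes=processes,
      self_cond_prob=0.5,
  )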
"""

backbone_network: arch_typing.ConditionalBackbone
conditioning_encoder: conditioning_encoder.BaseConditioningEncoder
prediction_type: PyTree[str]
processes: NestedProcess
self_cond_prob: float = 0.5
data_dtype: PyTree[DType] = jnp.float32
input_rescaler: PyTree[diffusion_network.InputRescaler | None] | None = None
time_rescaler: PyTree[diffusion_network.TimeRescaler | None] | None = None
rng_collection: str = 'self_conditioning'

def __post_init__(self):
super().__post_init__()

# Verify all prediction types are 'logits'
def _check_logits(pred_type):
if pred_type != 'logits':
raise ValueError(
f"All prediction types must be 'logits', got {pred_type}"
)

jax.tree.map(_check_logits, self.prediction_type)
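
# For example, prediction_type={'text': 'logits', 'image': 'logits'}
# passes this check; any other leaf value raises a ValueError.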

@nn.compact
@kt.typechecked
def __call__(
self,
time: TimeTree,
xt: DataTree,
conditioning: Conditioning | None,
is_training: bool,
) -> TargetInfoTree:

# 1. Rescale time and input (handling PyTrees)
if self.time_rescaler is not None:
time_rescaled = utils.lenient_map(
lambda t, tr: tr(t) if tr is not None else t, time, self.time_rescaler
)
else:
time_rescaled = time

if self.input_rescaler is not None:
xt_rescaled = utils.lenient_map(
lambda t, x, ir: ir(t, x) if ir is not None else x,
time,
xt,
self.input_rescaler,
)
else:
xt_rescaled = xt

# 2. Encode conditioning
conditioning_embeddings = cast(nn.Module, self.conditioning_encoder).copy(
name='ConditioningEncoder'
)(
time=time_rescaled,
conditioning=conditioning,
is_training=is_training,
)

# 3. Create zero logits for each leaf in the data tree
def _create_zero_logits(xt_leaf, process_leaf):
# Assumes process_leaf has `num_categories`
return jnp.zeros(
xt_leaf.shape[:-1] + (process_leaf.num_categories,),
dtype=xt_leaf.dtype,
)

zero_logits = jax.tree.map(
_create_zero_logits, xt_rescaled, self.processes.processes
)
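# Illustrative shapes: a data leaf of shape (B, L, C) paired with a process
# exposing num_categories=K yields zero logits of shape (B, L, K).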

# 4. First pass: run with zero logits
xt_with_zeros = jax.tree.map(
lambda x, z: jnp.concatenate([x, z], axis=-1), xt_rescaled, zero_logits
)
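# Each input leaf grows from (..., C) to (..., C + K) along its last axis.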

backbone_module = cast(nn.Module, self.backbone_network).copy(
name='Backbone'
)

first_output = backbone_module(
x=xt_with_zeros,
conditioning_embeddings=conditioning_embeddings,
is_training=is_training,
)

# Treat the first-pass prediction as a constant input (stop gradient flow).
x0_hat_logits = jax.tree.map(jax.lax.stop_gradient, first_output)

# 5. Apply self-conditioning mask during training
if is_training:
# Sample one Bernoulli mask per example, shared across all modalities.
# Any leaf of the data tree provides the batch size.
batch_size = jax.tree.leaves(xt)[0].shape[0]

do_self_cond = (
jax.random.uniform(
self.make_rng(self.rng_collection), shape=(batch_size,)
)
< self.self_cond_prob
)

def _apply_mask(logits_leaf, zero_leaf):
# Broadcast mask to match leaf dimensions
mask = do_self_cond.reshape(
(batch_size,) + (1,) * (logits_leaf.ndim - 1)
)
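# e.g. a (B, L, K) logits leaf gets a (B, 1, 1) mask for broadcasting.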
return jnp.where(mask, logits_leaf, zero_leaf)

x0_hat_logits = jax.tree.map(_apply_mask, x0_hat_logits, zero_logits)

# 6. Second pass: run with predicted logits concatenated
xt_with_x0_hat_logits = jax.tree.map(
lambda x, l: jnp.concatenate([x, l], axis=-1),
xt_rescaled,
x0_hat_logits,
)

backbone_outputs = backbone_module(
x=xt_with_x0_hat_logits,
conditioning_embeddings=conditioning_embeddings,
is_training=is_training,
)

# 7. Wrap outputs in prediction type structure
outputs = utils.lenient_map(
lambda out, pred_type: {pred_type: out},
backbone_outputs,
self.prediction_type,
)
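# e.g. {'text': {'logits': ...}, 'image': {'logits': ...}} for a
# two-modality tree.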
return outputs
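

################################################################################
# MARK: Self-conditioning mask example
################################################################################


# Illustrative, self-contained sketch of the per-example masking used in
# step 5 above. The helper, modality names, and shapes are made up for
# demonstration and are not used by the network.
def _demo_self_conditioning_mask(key: PRNGKey, self_cond_prob: float = 0.5):
  batch_size = 4
  # First-pass predictions for two hypothetical modalities.
  x0_hat_logits = {
      'text': jnp.ones((batch_size, 16, 8)),  # (B, L, K)
      'image': jnp.ones((batch_size, 4, 4, 8)),  # (B, H, W, K)
  }
  zero_logits = jax.tree.map(jnp.zeros_like, x0_hat_logits)

  # One Bernoulli draw per example, shared across both modalities.
  do_self_cond = jax.random.uniform(key, shape=(batch_size,)) < self_cond_prob

  def _apply_mask(logits_leaf, zero_leaf):
    # Reshape the (B,) mask to (B, 1, ..., 1) so it broadcasts per leaf.
    mask = do_self_cond.reshape((batch_size,) + (1,) * (logits_leaf.ndim - 1))
    return jnp.where(mask, logits_leaf, zero_leaf)

  # Selected examples keep their predictions; the rest fall back to zeros.
  return jax.tree.map(_apply_mask, x0_hat_logits, zero_logits)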