From aa87fb50608b3e3785be350293fb7511b8e6c813 Mon Sep 17 00:00:00 2001
From: Hackable Diffusion Authors
Date: Wed, 13 May 2026 10:43:15 -0700
Subject: [PATCH] Add NestedSelfConditioningDiffusionNetwork for multi-modal
 discrete self-conditioning.

PiperOrigin-RevId: 914943848
---
 hackable_diffusion/lib/multimodal.py      | 167 +++++++++++++++++++++-
 hackable_diffusion/lib/multimodal_test.py | 110 ++++++++++++++
 2 files changed, 276 insertions(+), 1 deletion(-)

diff --git a/hackable_diffusion/lib/multimodal.py b/hackable_diffusion/lib/multimodal.py
index 1c190cd..38110ce 100644
--- a/hackable_diffusion/lib/multimodal.py
+++ b/hackable_diffusion/lib/multimodal.py
@@ -53,8 +53,10 @@
 from typing import cast
 
 import flax.linen as nn
+from hackable_diffusion.lib import diffusion_network
 from hackable_diffusion.lib import hd_typing
 from hackable_diffusion.lib import jax_helpers
+from hackable_diffusion.lib.architecture import arch_typing
 from hackable_diffusion.lib.architecture import conditioning_encoder
 from hackable_diffusion.lib.corruption import base as corruption_base
 from hackable_diffusion.lib.inference import guidance as guidance_lib
@@ -71,6 +73,7 @@
 # MARK: Type Aliases
 ################################################################################
 
+DType = hd_typing.DType
 PRNGKey = hd_typing.PRNGKey
 PyTree = hd_typing.PyTree
 
@@ -531,7 +534,9 @@ def __call__(self, key: PRNGKey, data_spec: DataTree) -> TimeTree:
     def _call_sampler(key, sampler, data_spec):
       return sampler(key, data_spec)
 
-    return jax_helpers.tree_map_with_key(_call_sampler, key, self.samplers, data_spec)
+    return jax_helpers.tree_map_with_key(
+        _call_sampler, key, self.samplers, data_spec
+    )
 
 
 @dataclasses.dataclass(kw_only=True, frozen=True)
@@ -563,3 +568,163 @@ def _call_sampler(sampler, data_spec):
       return sampler(key, data_spec)
 
     return jax.tree.map(_call_sampler, self.samplers, data_spec)
+
+
+################################################################################
+# MARK: NestedSelfConditioningDiffusionNetwork
+################################################################################
+
+
+class NestedSelfConditioningDiffusionNetwork(
+    nn.Module, diffusion_network.BaseDiffusionNetwork
+):
+  """Multi-modal DiffusionNetwork with self-conditioning on predicted logits.
+
+  This class generalizes `SelfConditioningDiffusionNetwork` to PyTree data.
+  It assumes ALL modalities in the PyTree are discrete and require
+  self-conditioning.
+
+  Attributes:
+    backbone_network: The backbone network to use. Must accept PyTree inputs
+      whose leaves have self-conditioning logits concatenated on the last axis.
+    conditioning_encoder: The conditioning encoder to use.
+    prediction_type: PyTree of strings; every leaf must be 'logits'.
+    processes: A NestedProcess containing the corruption processes for each
+      modality (used to get `num_categories`).
+    self_cond_prob: Probability of applying self-conditioning during training.
+    data_dtype: PyTree of dtypes for each modality.
+    input_rescaler: Optional PyTree of input rescalers.
+    time_rescaler: Optional PyTree of time rescalers.
+    rng_collection: PRNG collection name for the self-conditioning mask.
+ """ + + backbone_network: arch_typing.ConditionalBackbone + conditioning_encoder: conditioning_encoder.BaseConditioningEncoder + prediction_type: PyTree[str] + processes: NestedProcess + self_cond_prob: float = 0.5 + data_dtype: PyTree[DType] = jnp.float32 + input_rescaler: PyTree[diffusion_network.InputRescaler | None] | None = None + time_rescaler: PyTree[diffusion_network.TimeRescaler | None] | None = None + rng_collection: str = 'self_conditioning' + + def __post_init__(self): + super().__post_init__() + + # Verify all prediction types are 'logits' + def _check_logits(pred_type): + if pred_type != 'logits': + raise ValueError( + f"All prediction types must be 'logits', got {pred_type}" + ) + + jax.tree.map(_check_logits, self.prediction_type) + + @nn.compact + @kt.typechecked + def __call__( + self, + time: TimeTree, + xt: DataTree, + conditioning: Conditioning | None, + is_training: bool, + ) -> TargetInfoTree: + + # 1. Rescale time and input (handling PyTrees) + if self.time_rescaler is not None: + time_rescaled = jax_helpers.lenient_map( + lambda t, tr: tr(t) if tr is not None else t, time, self.time_rescaler + ) + else: + time_rescaled = time + + if self.input_rescaler is not None: + xt_rescaled = jax_helpers.lenient_map( + lambda t, x, ir: ir(t, x) if ir is not None else x, + time, + xt, + self.input_rescaler, + ) + else: + xt_rescaled = xt + + # 2. Encode conditioning + conditioning_embeddings = cast(nn.Module, self.conditioning_encoder).copy( + name='ConditioningEncoder' + )( + time=time_rescaled, + conditioning=conditioning, + is_training=is_training, + ) + + # 3. Create zero logits for each leaf in the data tree + def _create_zero_logits(xt_leaf, process_leaf): + # Assumes process_leaf has `num_categories` + return jnp.zeros( + xt_leaf.shape[:-1] + (process_leaf.num_categories,), + dtype=xt_leaf.dtype, + ) + + zero_logits = jax.tree.map( + _create_zero_logits, xt_rescaled, self.processes.processes + ) + + # 4. First pass: run with zero logits + xt_with_zeros = jax.tree.map( + lambda x, z: jnp.concatenate([x, z], axis=-1), xt_rescaled, zero_logits + ) + + backbone_module = cast(nn.Module, self.backbone_network).copy( + name='Backbone' + ) + + first_output = backbone_module( + x=xt_with_zeros, + conditioning_embeddings=conditioning_embeddings, + is_training=is_training, + ) + + x0_hat_logits = jax.tree.map(jax.lax.stop_gradient, first_output) + + # 5. Apply self-conditioning mask during training + if is_training: + # We assume a global mask for the entire batch across all modalities + # Find a leaf to get the batch size + flat_xt, _ = jax.tree_util.tree_flatten(xt) + batch_size = flat_xt[0].shape[0] + + do_self_cond = ( + jax.random.uniform( + self.make_rng(self.rng_collection), shape=(batch_size,) + ) + < self.self_cond_prob + ) + + def _apply_mask(logits_leaf, zero_leaf): + # Broadcast mask to match leaf dimensions + mask = do_self_cond.reshape( + (batch_size,) + (1,) * (logits_leaf.ndim - 1) + ) + return jnp.where(mask, logits_leaf, zero_leaf) + + x0_hat_logits = jax.tree.map(_apply_mask, x0_hat_logits, zero_logits) + + # 6. Second pass: run with predicted logits concatenated + xt_with_x0_hat_logits = jax.tree.map( + lambda x, l: jnp.concatenate([x, l], axis=-1), + xt_rescaled, + x0_hat_logits, + ) + + backbone_outputs = backbone_module( + x=xt_with_x0_hat_logits, + conditioning_embeddings=conditioning_embeddings, + is_training=is_training, + ) + + # 7. 
+    # 7. Wrap outputs in prediction type structure.
+    return jax_helpers.lenient_map(
+        lambda out, pred_type: {pred_type: out},
+        backbone_outputs,
+        self.prediction_type,
+    )
diff --git a/hackable_diffusion/lib/multimodal_test.py b/hackable_diffusion/lib/multimodal_test.py
index 2e54845..53f37ad 100644
--- a/hackable_diffusion/lib/multimodal_test.py
+++ b/hackable_diffusion/lib/multimodal_test.py
@@ -519,5 +519,115 @@ def test_joint_nested_time_sampler(self):
     chex.assert_trees_all_equal_structs(time, data_spec)
 
 
+class NestedSelfConditioningDiffusionNetworkTest(parameterized.TestCase):
+  """Tests for NestedSelfConditioningDiffusionNetwork.
+
+  Uses a DummyBackbone that distinguishes the two forward passes:
+    - Pass 1 (second half is zeros): returns first_half + 1.0
+    - Pass 2 (second half has self-cond data): returns second_half * 10.0
+
+  Expected output encodes the self-conditioning semantics:
+    - 10.0 → self-conditioning is active and passes the correct prediction
+    - 1.0 → self-conditioning is inactive (both passes see zeros)
+  """
+
+  @parameterized.named_parameters(
+      ('training_with_self_cond', 1.0, True, 10.0),
+      ('inference_always_self_conds', 0.0, False, 10.0),
+      ('training_no_self_cond', 0.0, True, 1.0),
+  )
+  def test_nested_self_conditioning_network(
+      self,
+      self_cond_prob: float,
+      is_training: bool,
+      expected_value: float,
+  ):
+    num_categories = 10
+    batch_size = 2
+
+    class DummyBackbone(nn.Module, arch_typing.ConditionalBackbone):
+      """Backbone that reads and uses the self-conditioning logits."""
+
+      @nn.compact
+      def __call__(self, x, conditioning_embeddings, is_training):
+
+        def _forward(arr):
+          first_half = arr[..., : arr.shape[-1] // 2]
+          second_half = arr[..., arr.shape[-1] // 2 :]
+          has_self_cond = jnp.any(second_half != 0)
+          # Pass 1 (zeros in second half): return first_half + 1.0.
+          # Pass 2 (prediction in second half): return second_half * 10.0.
+          return jnp.where(has_self_cond, second_half * 10.0, first_half + 1.0)
+
+        return jax.tree.map(_forward, x)
+
+    mock_process = mock.create_autospec(
+        discrete.CategoricalProcess, instance=True
+    )
+    mock_process.num_categories = num_categories
+
+    processes = multimodal.NestedProcess(
+        processes={'a': mock_process, 'b': {'c': mock_process}}
+    )
+
+    xt = {
+        'a': jnp.zeros((batch_size, 4, num_categories)),
+        'b': {'c': jnp.zeros((batch_size, num_categories))},
+    }
+    time = {
+        'a': jnp.zeros((batch_size, 1)),
+        'b': {'c': jnp.zeros((batch_size, 1))},
+    }
+
+    class DummyCondEncoder(nn.Module):
+
+      @nn.compact
+      def __call__(self, time, conditioning, is_training):
+        return {}
+
+    network = multimodal.NestedSelfConditioningDiffusionNetwork(
+        backbone_network=DummyBackbone(),
+        conditioning_encoder=DummyCondEncoder(),
+        prediction_type={'a': 'logits', 'b': {'c': 'logits'}},
+        processes=processes,
+        self_cond_prob=self_cond_prob,
+    )
+
+    key = jax.random.PRNGKey(0)
+    params_key, sc_key = jax.random.split(key)
+
+    variables = network.init(
+        {'params': params_key, 'self_conditioning': sc_key},
+        time=time,
+        xt=xt,
+        conditioning=None,
+        is_training=is_training,
+    )
+
+    apply_kwargs = dict(
+        time=time,
+        xt=xt,
+        conditioning=None,
+        is_training=is_training,
+    )
+    if is_training:
+      apply_kwargs['rngs'] = {'self_conditioning': sc_key}
+
+    output = network.apply(variables, **apply_kwargs)
+
+    expected_output = {
+        'a': {
+            'logits': jnp.full((batch_size, 4, num_categories), expected_value)
+        },
+        'b': {
+            'c': {
+                'logits': jnp.full((batch_size, num_categories), expected_value)
+            }
+        },
+    }
+
+    chex.assert_trees_all_close(output, expected_output)
+
+
 if __name__ == '__main__':
   absltest.main()
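
Reviewer note (not part of the patch): a minimal usage sketch of the new
network, mirroring the unit test above. `TinyBackbone`, `TinyCondEncoder`, and
`TinyProcess` are hypothetical stand-ins (the test mocks
`discrete.CategoricalProcess` instead); any real
`arch_typing.ConditionalBackbone` and conditioning encoder slot in the same
way.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    from hackable_diffusion.lib import multimodal
    from hackable_diffusion.lib.architecture import arch_typing

    class TinyBackbone(nn.Module, arch_typing.ConditionalBackbone):
      """Stand-in backbone: slices each concatenated leaf back to logits."""

      @nn.compact
      def __call__(self, x, conditioning_embeddings, is_training):
        # Each leaf is [noisy one-hot | self-cond logits]; emit logits width.
        return jax.tree.map(lambda arr: arr[..., arr.shape[-1] // 2 :], x)

    class TinyCondEncoder(nn.Module):

      @nn.compact
      def __call__(self, time, conditioning, is_training):
        return {}

    class TinyProcess:
      """Stand-in exposing the one attribute the network reads."""

      num_categories = 32

    network = multimodal.NestedSelfConditioningDiffusionNetwork(
        backbone_network=TinyBackbone(),
        conditioning_encoder=TinyCondEncoder(),
        prediction_type={'tokens': 'logits'},
        processes=multimodal.NestedProcess(processes={'tokens': TinyProcess()}),
        self_cond_prob=0.5,
    )

    xt = {'tokens': jnp.zeros((8, 16, 32))}  # one-hot inputs, 32 categories
    time = {'tokens': jnp.zeros((8, 1))}

    params_key, sc_key = jax.random.split(jax.random.PRNGKey(0))
    variables = network.init(
        {'params': params_key, 'self_conditioning': sc_key},
        time=time, xt=xt, conditioning=None, is_training=True,
    )
    # Training call: the self_conditioning RNG draws the per-example mask that
    # decides whether the (stop-gradient) first-pass logits reach pass two.
    out = network.apply(
        variables, time=time, xt=xt, conditioning=None, is_training=True,
        rngs={'self_conditioning': sc_key},
    )
    # out == {'tokens': {'logits': f32[8, 16, 32]}}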
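
A second sketch (also not part of the patch) isolating the step-5 masking
semantics in plain JAX: one Bernoulli draw per example, reshaped so it
broadcasts over every non-batch axis of a leaf, selecting between the
first-pass logits and zeros.

    import jax
    import jax.numpy as jnp

    batch_size = 4
    logits_leaf = jnp.ones((batch_size, 16, 32))  # stand-in first-pass logits
    zero_leaf = jnp.zeros_like(logits_leaf)

    self_cond_prob = 0.5
    do_self_cond = (
        jax.random.uniform(jax.random.PRNGKey(1), shape=(batch_size,))
        < self_cond_prob
    )
    # (batch,) -> (batch, 1, 1): broadcast over token and category axes.
    mask = do_self_cond.reshape((batch_size,) + (1,) * (logits_leaf.ndim - 1))
    masked = jnp.where(mask, logits_leaf, zero_leaf)
    assert masked.shape == logits_leaf.shape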