diff --git a/hackable_diffusion/lib/multimodal.py b/hackable_diffusion/lib/multimodal.py
index 5fdd686..62e3e92 100644
--- a/hackable_diffusion/lib/multimodal.py
+++ b/hackable_diffusion/lib/multimodal.py
@@ -53,8 +53,10 @@ from typing import cast
 
 import flax.linen as nn
+from hackable_diffusion.lib import diffusion_network
 from hackable_diffusion.lib import hd_typing
 from hackable_diffusion.lib import utils
+from hackable_diffusion.lib.architecture import arch_typing
 from hackable_diffusion.lib.architecture import conditioning_encoder
 from hackable_diffusion.lib.corruption import base as corruption_base
 from hackable_diffusion.lib.inference import guidance as guidance_lib
 
@@ -71,6 +73,7 @@
 # MARK: Type Aliases
 ################################################################################
 
+DType = hd_typing.DType
 PRNGKey = hd_typing.PRNGKey
 PyTree = hd_typing.PyTree
 
@@ -563,3 +566,164 @@ def _call_sampler(sampler, data_spec):
       return sampler(key, data_spec)
 
     return jax.tree.map(_call_sampler, self.samplers, data_spec)
+
+
+################################################################################
+# MARK: NestedSelfConditioningDiffusionNetwork
+################################################################################
+
+
+class NestedSelfConditioningDiffusionNetwork(
+    nn.Module, diffusion_network.BaseDiffusionNetwork
+):
+  """Multi-modal DiffusionNetwork with self-conditioning on predicted logits.
+
+  This class generalizes `SelfConditioningDiffusionNetwork` to PyTree data.
+  It assumes ALL modalities in the PyTree are discrete and require
+  self-conditioning.
+
+  Attributes:
+    backbone_network: The backbone network to use. Must accept PyTree inputs
+      where each leaf has concatenated logits.
+    conditioning_encoder: The conditioning encoder to use.
+    prediction_type: PyTree of strings; every leaf must be 'logits'.
+    processes: A NestedProcess containing the corruption processes for each
+      modality (used to get `num_categories`).
+    self_cond_prob: Probability of applying self-conditioning during training.
+    data_dtype: PyTree of dtypes for each modality.
+    input_rescaler: Optional PyTree of input rescalers.
+    time_rescaler: Optional PyTree of time rescalers.
+    rng_collection: PRNG collection name for the self-conditioning mask.
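+
+  Example:
+    A minimal usage sketch; `my_backbone`, `my_conditioning_encoder`,
+    `my_discrete_process`, `variables`, `time`, `xt`, and `rng` are
+    hypothetical placeholders, not exports of this library::
+
+      network = NestedSelfConditioningDiffusionNetwork(
+          backbone_network=my_backbone,
+          conditioning_encoder=my_conditioning_encoder,
+          prediction_type={'tokens': 'logits'},
+          processes=NestedProcess(processes={'tokens': my_discrete_process}),
+          self_cond_prob=0.5,
+      )
+      outputs = network.apply(
+          variables,
+          time=time,
+          xt=xt,
+          conditioning=None,
+          is_training=True,
+          rngs={'self_conditioning': rng},
+      )
+      # outputs mirrors prediction_type: outputs['tokens']['logits'].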
+  """
+
+  backbone_network: arch_typing.ConditionalBackbone
+  conditioning_encoder: conditioning_encoder.BaseConditioningEncoder
+  prediction_type: PyTree[str]
+  processes: NestedProcess
+  self_cond_prob: float = 0.5
+  data_dtype: PyTree[DType] = jnp.float32
+  input_rescaler: PyTree[diffusion_network.InputRescaler | None] | None = None
+  time_rescaler: PyTree[diffusion_network.TimeRescaler | None] | None = None
+  rng_collection: str = 'self_conditioning'
+
+  def __post_init__(self):
+    super().__post_init__()
+
+    # Verify all prediction types are 'logits'.
+    def _check_logits(pred_type):
+      if pred_type != 'logits':
+        raise ValueError(
+            f"All prediction types must be 'logits', got {pred_type}"
+        )
+
+    jax.tree.map(_check_logits, self.prediction_type)
+
+  @nn.compact
+  @kt.typechecked
+  def __call__(
+      self,
+      time: TimeTree,
+      xt: DataTree,
+      conditioning: Conditioning | None,
+      is_training: bool,
+  ) -> TargetInfoTree:
+
+    # 1. Rescale time and input (handling PyTrees).
+    if self.time_rescaler is not None:
+      time_rescaled = utils.lenient_map(
+          lambda t, tr: tr(t) if tr is not None else t, time, self.time_rescaler
+      )
+    else:
+      time_rescaled = time
+
+    if self.input_rescaler is not None:
+      xt_rescaled = utils.lenient_map(
+          lambda t, x, ir: ir(t, x) if ir is not None else x,
+          time,
+          xt,
+          self.input_rescaler,
+      )
+    else:
+      xt_rescaled = xt
+
+    # 2. Encode conditioning.
+    conditioning_embeddings = cast(nn.Module, self.conditioning_encoder).copy(
+        name='ConditioningEncoder'
+    )(
+        time=time_rescaled,
+        conditioning=conditioning,
+        is_training=is_training,
+    )
+
+    # 3. Create zero logits for each leaf in the data tree.
+    def _create_zero_logits(xt_leaf, process_leaf):
+      # Assumes process_leaf has `num_categories`.
+      return jnp.zeros(
+          xt_leaf.shape[:-1] + (process_leaf.num_categories,),
+          dtype=xt_leaf.dtype,
+      )
+
+    zero_logits = jax.tree.map(
+        _create_zero_logits, xt_rescaled, self.processes.processes
+    )
+
+    # 4. First pass: run with zero logits.
+    xt_with_zeros = jax.tree.map(
+        lambda x, z: jnp.concatenate([x, z], axis=-1), xt_rescaled, zero_logits
+    )
+
+    backbone_module = cast(nn.Module, self.backbone_network).copy(
+        name='Backbone'
+    )
+
+    first_output = backbone_module(
+        x=xt_with_zeros,
+        conditioning_embeddings=conditioning_embeddings,
+        is_training=is_training,
+    )
+
+    x0_hat_logits = jax.tree.map(jax.lax.stop_gradient, first_output)
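+
+    # NOTE: The first-pass prediction is treated as a constant input to the
+    # second pass (hence the `stop_gradient` above), so gradients only flow
+    # through the second pass. This mirrors the usual self-conditioning
+    # recipe for diffusion models.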
+
+    # 5. Apply the self-conditioning mask during training.
+    if is_training:
+      # Draw one mask entry per batch element, shared across all modalities.
+      # Use an arbitrary leaf to read off the batch size.
+      flat_xt, _ = jax.tree_util.tree_flatten(xt)
+      batch_size = flat_xt[0].shape[0]
+
+      do_self_cond = (
+          jax.random.uniform(
+              self.make_rng(self.rng_collection), shape=(batch_size,)
+          )
+          < self.self_cond_prob
+      )
+
+      def _apply_mask(logits_leaf, zero_leaf):
+        # Broadcast the mask to match the leaf dimensions.
+        mask = do_self_cond.reshape(
+            (batch_size,) + (1,) * (logits_leaf.ndim - 1)
+        )
+        return jnp.where(mask, logits_leaf, zero_leaf)
+
+      x0_hat_logits = jax.tree.map(_apply_mask, x0_hat_logits, zero_logits)
+
+    # 6. Second pass: run with predicted logits concatenated.
+    xt_with_x0_hat_logits = jax.tree.map(
+        lambda x, l: jnp.concatenate([x, l], axis=-1),
+        xt_rescaled,
+        x0_hat_logits,
+    )
+
+    backbone_outputs = backbone_module(
+        x=xt_with_x0_hat_logits,
+        conditioning_embeddings=conditioning_embeddings,
+        is_training=is_training,
+    )
+
+    # 7. Wrap outputs in prediction type structure.
+    outputs = utils.lenient_map(
+        lambda out, pred_type: {pred_type: out},
+        backbone_outputs,
+        self.prediction_type,
+    )
+    return outputs
diff --git a/hackable_diffusion/lib/multimodal_test.py b/hackable_diffusion/lib/multimodal_test.py
index 6d5eefa..83201a2 100644
--- a/hackable_diffusion/lib/multimodal_test.py
+++ b/hackable_diffusion/lib/multimodal_test.py
@@ -521,5 +521,228 @@ def test_joint_nested_time_sampler(self):
     chex.assert_trees_all_equal_structs(time, data_spec)
 
 
+class NestedSelfConditioningDiffusionNetworkTest(parameterized.TestCase):
+
+  def test_nested_self_conditioning_network(self):
+    num_categories = 10
+    batch_size = 2
+
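+    # This dummy backbone ignores the self-conditioning channels (the second
+    # half of the last axis) and returns the first half plus one, so the
+    # expected output does not depend on the self-conditioning logits.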
+    class DummyBackbone(nn.Module, arch_typing.ConditionalBackbone):
+
+      @nn.compact
+      def __call__(self, x, conditioning_embeddings, is_training):
+        return jax.tree.map(lambda arr: arr[..., : arr.shape[-1] // 2] + 1.0, x)
+
+    mock_process = mock.MagicMock()
+    mock_process.num_categories = num_categories
+
+    processes = multimodal.NestedProcess(
+        processes={'a': mock_process, 'b': {'c': mock_process}}
+    )
+
+    xt = {
+        'a': jnp.zeros((batch_size, 4, num_categories)),
+        'b': {'c': jnp.zeros((batch_size, num_categories))},
+    }
+    time = {
+        'a': jnp.zeros((batch_size, 1)),
+        'b': {'c': jnp.zeros((batch_size, 1))},
+    }
+
+    class DummyCondEncoder(nn.Module):
+
+      @nn.compact
+      def __call__(self, time, conditioning, is_training):
+        return {}
+
+    network = multimodal.NestedSelfConditioningDiffusionNetwork(
+        backbone_network=DummyBackbone(),
+        conditioning_encoder=DummyCondEncoder(),
+        prediction_type={'a': 'logits', 'b': {'c': 'logits'}},
+        processes=processes,
+        self_cond_prob=1.0,  # Always self-condition.
+    )
+
+    key = jax.random.PRNGKey(0)
+    params_key, sc_key = jax.random.split(key)
+
+    variables = network.init(
+        {'params': params_key, 'self_conditioning': sc_key},
+        time=time,
+        xt=xt,
+        conditioning=None,
+        is_training=True,
+    )
+
+    output = network.apply(
+        variables,
+        time=time,
+        xt=xt,
+        conditioning=None,
+        is_training=True,
+        rngs={'self_conditioning': sc_key},
+    )
+
+    self.assertIsInstance(output, dict)
+    self.assertIn('a', output)
+    self.assertIn('b', output)
+
+    # Expected output: xt + 1.0 = zeros + 1.0 = ones, because DummyBackbone
+    # returns the first half of its input plus 1.0, and the first half is xt.
+    expected_output = {
+        'a': {'logits': jnp.ones((batch_size, 4, num_categories))},
+        'b': {'c': {'logits': jnp.ones((batch_size, num_categories))}},
+    }
+
+    chex.assert_trees_all_close(output, expected_output)
+
+  def test_nested_self_conditioning_network_inference(self):
+    num_categories = 10
+    batch_size = 2
+
+    class DummyBackbone(nn.Module, arch_typing.ConditionalBackbone):
+
+      @nn.compact
+      def __call__(self, x, conditioning_embeddings, is_training):
+        return jax.tree.map(lambda arr: arr[..., : arr.shape[-1] // 2] + 1.0, x)
+
+    mock_process = mock.MagicMock()
+    mock_process.num_categories = num_categories
+
+    processes = multimodal.NestedProcess(
+        processes={'a': mock_process, 'b': {'c': mock_process}}
+    )
+
+    xt = {
+        'a': jnp.zeros((batch_size, 4, num_categories)),
+        'b': {'c': jnp.zeros((batch_size, num_categories))},
+    }
+    time = {
+        'a': jnp.zeros((batch_size, 1)),
+        'b': {'c': jnp.zeros((batch_size, 1))},
+    }
+
+    class DummyCondEncoder(nn.Module):
+
+      @nn.compact
+      def __call__(self, time, conditioning, is_training):
+        return {}
+
+    network = multimodal.NestedSelfConditioningDiffusionNetwork(
+        backbone_network=DummyBackbone(),
+        conditioning_encoder=DummyCondEncoder(),
+        prediction_type={'a': 'logits', 'b': {'c': 'logits'}},
+        processes=processes,
+        self_cond_prob=0.0,  # Should still self-condition in inference mode.
+    )
+
+    key = jax.random.PRNGKey(0)
+    params_key, sc_key = jax.random.split(key)
+
+    variables = network.init(
+        {'params': params_key, 'self_conditioning': sc_key},
+        time=time,
+        xt=xt,
+        conditioning=None,
+        is_training=False,
+    )
+
+    output = network.apply(
+        variables,
+        time=time,
+        xt=xt,
+        conditioning=None,
+        is_training=False,
+    )
+
+    self.assertIsInstance(output, dict)
+    self.assertIn('a', output)
+    self.assertIn('b', output)
+
+    expected_output = {
+        'a': {'logits': jnp.ones((batch_size, 4, num_categories))},
+        'b': {'c': {'logits': jnp.ones((batch_size, num_categories))}},
+    }
+
+    chex.assert_trees_all_close(output, expected_output)
+
+  def test_nested_self_conditioning_network_no_self_cond(self):
+    num_categories = 10
+    batch_size = 2
+
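+    # This backbone acts as a probe: per leaf, it returns all-ones if any
+    # self-conditioning channel (the second half of the last axis) is
+    # non-zero, and all-zeros otherwise, so non-zero logits fed into the
+    # second pass would surface as ones in the output.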
+    class DummyBackbone(nn.Module, arch_typing.ConditionalBackbone):
+
+      @nn.compact
+      def __call__(self, x, conditioning_embeddings, is_training):
+        def _check_non_zero(arr):
+          second_half = arr[..., arr.shape[-1] // 2 :]
+          return jnp.where(
+              jnp.any(second_half != 0),
+              jnp.ones_like(arr[..., : arr.shape[-1] // 2]),
+              jnp.zeros_like(arr[..., : arr.shape[-1] // 2]),
+          )
+
+        return jax.tree.map(_check_non_zero, x)
+
+    mock_process = mock.MagicMock()
+    mock_process.num_categories = num_categories
+
+    processes = multimodal.NestedProcess(
+        processes={'a': mock_process, 'b': {'c': mock_process}}
+    )
+
+    xt = {
+        'a': jnp.zeros((batch_size, 4, num_categories)),
+        'b': {'c': jnp.zeros((batch_size, num_categories))},
+    }
+    time = {
+        'a': jnp.zeros((batch_size, 1)),
+        'b': {'c': jnp.zeros((batch_size, 1))},
+    }
+
+    class DummyCondEncoder(nn.Module):
+
+      @nn.compact
+      def __call__(self, time, conditioning, is_training):
+        return {}
+
+    network = multimodal.NestedSelfConditioningDiffusionNetwork(
+        backbone_network=DummyBackbone(),
+        conditioning_encoder=DummyCondEncoder(),
+        prediction_type={'a': 'logits', 'b': {'c': 'logits'}},
+        processes=processes,
+        self_cond_prob=0.0,  # Never self-condition during training.
+    )
+
+    key = jax.random.PRNGKey(0)
+    params_key, sc_key = jax.random.split(key)
+
+    variables = network.init(
+        {'params': params_key, 'self_conditioning': sc_key},
+        time=time,
+        xt=xt,
+        conditioning=None,
+        is_training=True,
+    )
+
+    output = network.apply(
+        variables,
+        time=time,
+        xt=xt,
+        conditioning=None,
+        is_training=True,
+        rngs={'self_conditioning': sc_key},
+    )
+
+    self.assertIsInstance(output, dict)
+
+    expected_output = {
+        'a': {'logits': jnp.zeros((batch_size, 4, num_categories))},
+        'b': {'c': {'logits': jnp.zeros((batch_size, num_categories))}},
+    }
+
+    chex.assert_trees_all_close(output, expected_output)
+
+
 if __name__ == '__main__':
   absltest.main()