|
| 1 | +# Multimodal Diffusion |
| 2 | + |
| 3 | +This document explains how to use Hackable Diffusion's "Nested" wrappers to |
| 4 | +build multimodal diffusion models that operate on PyTree-structured data. |
| 5 | + |
| 6 | +The multimodal wrappers are located in `lib/multimodal.py`. |
| 7 | + |
| 8 | +[TOC] |
| 9 | + |
| 10 | +## Overview |
| 11 | + |
| 12 | +Hackable Diffusion's core protocols (`CorruptionProcess`, `SamplerStep`, |
| 13 | +`DiffusionLoss`, etc.) are designed around single-modal arrays. To handle |
| 14 | +multimodal data — where different parts of the input (e.g., image + labels, |
| 15 | +continuous + discrete) require different diffusion treatments — the library |
| 16 | +provides **Nested wrappers**. |
| 17 | + |
| 18 | +Each wrapper takes a PyTree of single-modal components that matches the |
| 19 | +structure of your data. When called, it dispatches each method to the |
| 20 | +corresponding component-data pair. |
| 21 | + |
| 22 | +## Available Wrappers |
| 23 | + |
| 24 | +### Training |
| 25 | + |
| 26 | +* **`NestedProcess`**: Applies different corruption processes per modality. |
| 27 | +* **`NestedDiffusionLoss`**: Computes different loss functions per modality. |
| 28 | +* **`NestedTimeSampler`**: Samples timesteps independently per modality. |
| 29 | + |
| 30 | +### Sampling |
| 31 | + |
| 32 | +* **`NestedSamplerStep`**: Runs different sampler algorithms per modality. |
| 33 | +* **`NestedTimeSchedule`**: Uses different time discretizations per modality. |
| 34 | +* **`NestedGuidanceFn`**: Applies different guidance functions per modality. |
| 35 | + |
| 36 | +## Key Concept: Structure Matching |
| 37 | + |
| 38 | +The **structure of your Nested wrapper must match the structure of your data**. |
| 39 | +For example, if your data is a dictionary `{"image": ..., "label": ...}`, your |
| 40 | +`NestedProcess` must also be keyed by `{"image": ..., "label": ...}`. |
| 41 | + |
| 42 | +```python |
| 43 | +data = { |
| 44 | + "image": jnp.zeros((batch, 32, 32, 3)), |
| 45 | + "label": jnp.zeros((batch, 1), dtype=jnp.int32), |
| 46 | +} |
| 47 | + |
| 48 | +process = NestedProcess( |
| 49 | + processes={ |
| 50 | + "image": GaussianProcess(schedule=CosineSchedule()), |
| 51 | + "label": CategoricalProcess.masking_process( |
| 52 | + schedule=LinearDiscreteSchedule(), num_categories=10, |
| 53 | + ), |
| 54 | + } |
| 55 | +) |
| 56 | +``` |
| 57 | + |
| 58 | +## Example: Multimodal Training Setup |
| 59 | + |
| 60 | +```python |
| 61 | +from hackable_diffusion.lib.multimodal import ( |
| 62 | + NestedProcess, |
| 63 | + NestedDiffusionLoss, |
| 64 | + NestedTimeSampler, |
| 65 | +) |
| 66 | +from hackable_diffusion.lib.corruption.gaussian import GaussianProcess |
| 67 | +from hackable_diffusion.lib.corruption.discrete import CategoricalProcess |
| 68 | +from hackable_diffusion.lib.corruption.schedules import ( |
| 69 | + CosineSchedule, |
| 70 | + LinearDiscreteSchedule, |
| 71 | +) |
| 72 | +from hackable_diffusion.lib.training.gaussian_loss import SiD2Loss |
| 73 | +from hackable_diffusion.lib.training.discrete_loss import MD4Loss |
| 74 | +from hackable_diffusion.lib.training.time_sampling import UniformTimeSampler |
| 75 | + |
| 76 | +# 1. Define per-modality corruption processes |
| 77 | +process = NestedProcess( |
| 78 | + processes={ |
| 79 | + "image": GaussianProcess(schedule=CosineSchedule()), |
| 80 | + "label": CategoricalProcess.masking_process( |
| 81 | + schedule=LinearDiscreteSchedule(), num_categories=10, |
| 82 | + ), |
| 83 | + } |
| 84 | +) |
| 85 | + |
| 86 | +# 2. Define per-modality losses |
| 87 | +loss_fn = NestedDiffusionLoss( |
| 88 | + losses={ |
| 89 | + "image": SiD2Loss(schedule=CosineSchedule()), |
| 90 | + "label": MD4Loss(schedule=LinearDiscreteSchedule()), |
| 91 | + } |
| 92 | +) |
| 93 | + |
| 94 | +# 3. Define per-modality time sampling (optional — can also share time) |
| 95 | +time_sampler = NestedTimeSampler( |
| 96 | + time_samplers={ |
| 97 | + "image": UniformTimeSampler(), |
| 98 | + "label": UniformTimeSampler(), |
| 99 | + } |
| 100 | +) |
| 101 | +``` |
| 102 | + |
| 103 | +## Example: Multimodal Sampling Setup |
| 104 | + |
| 105 | +```python |
| 106 | +from hackable_diffusion.lib.multimodal import ( |
| 107 | + NestedSamplerStep, |
| 108 | + NestedTimeSchedule, |
| 109 | +) |
| 110 | +from hackable_diffusion.lib.sampling.gaussian_step_sampler import DDIMStep |
| 111 | +from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteDDIMStep |
| 112 | +from hackable_diffusion.lib.sampling.time_scheduling import UniformTimeSchedule |
| 113 | + |
| 114 | +sampler_step = NestedSamplerStep( |
| 115 | + sampler_steps={ |
| 116 | + "image": DDIMStep( |
| 117 | + corruption_process=gaussian_process, |
| 118 | + stoch_coeff=0.0, |
| 119 | + ), |
| 120 | + "label": DiscreteDDIMStep( |
| 121 | + corruption_process=categorical_process, |
| 122 | + ), |
| 123 | + } |
| 124 | +) |
| 125 | + |
| 126 | +time_schedule = NestedTimeSchedule( |
| 127 | + time_schedules={ |
| 128 | + "image": UniformTimeSchedule(), |
| 129 | + "label": UniformTimeSchedule(), |
| 130 | + } |
| 131 | +) |
| 132 | +``` |
| 133 | + |
| 134 | +## How It Works |
| 135 | + |
| 136 | +Internally, Nested wrappers use `utils.lenient_map` to traverse the data and |
| 137 | +component PyTrees in parallel, calling the corresponding method on each |
| 138 | +component with its matching data leaf. This means: |
| 139 | + |
| 140 | +* Any nesting depth works (dictionaries, named tuples, etc.). |
| 141 | +* Single-modal and multimodal code share the same protocols. |
| 142 | +* You can mix and match any combination of corruption processes, samplers, and |
| 143 | + losses. |
| 144 | + |
| 145 | +The `mnist_multimodal` notebook provides a complete end-to-end example. |
0 commit comments