
Commit ecd41f8

ccrepy authored and Hackable Diffusion Authors committed
Update Documentation.
PiperOrigin-RevId: 911379267
1 parent 851faa6 commit ecd41f8

8 files changed

Lines changed: 323 additions & 13 deletions


docs/architecture.md

Lines changed: 44 additions & 0 deletions
@@ -134,6 +134,50 @@ print(f"Output shape: {output.shape}")
```python
print(f"Output shape: {output.shape}")
# Output shape: (1, 64, 64, 3)
```

### `DiT`

(`lib/architecture/dit.py`)

The `DiT` class implements a **Diffusion Transformer** backbone based on
<https://arxiv.org/abs/2212.09748>. It uses adaptive layer norm zero
(adaLN-Zero) as the conditioning mechanism. The architecture consists of
repeated transformer blocks with optional encoder/decoder and absolute
positional encoding.

Key parameters:

* `num_blocks`: Number of DiT blocks.
* `block`: A DiT block module (e.g., `DiTBlockAdaLNZero`).
* `encoder`: Optional encoder (e.g., `Patchify` for image inputs).
* `decoder`: Optional decoder (e.g., `DePatchify` for image outputs).
* `absolute_posenc`: Optional positional encoding module.
* `use_padding_mask`: Whether to mask out padding tokens (for tokenized
  inputs).

The `DiT` expects an `ADAPTIVE_NORM` conditioning embedding. The `mnist_dit`
notebook demonstrates its usage.
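
To make the wiring concrete, here is a minimal construction sketch. The
keyword names mirror the parameter list above, but the block/encoder/decoder
constructors and their arguments are assumptions, not the verified API; check
`lib/architecture/dit.py` for the actual signatures.

```python
# Hypothetical sketch; constructor arguments are assumptions.
from hackable_diffusion.lib.architecture import dit

backbone = dit.DiT(
    num_blocks=12,                  # number of DiT blocks
    block=dit.DiTBlockAdaLNZero(),  # the repeated adaLN-Zero transformer block
    encoder=dit.Patchify(),         # image -> token sequence
    decoder=dit.DePatchify(),       # token sequence -> image
    absolute_posenc=None,           # optional positional encoding module
    use_padding_mask=False,         # only needed for tokenized inputs
)
```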
## Diffusion Network

(`lib/diffusion_network.py`)

The **`DiffusionNetwork`** class is the primary entry point for constructing a
complete diffusion model. It composes a backbone (e.g., `Unet` or `DiT`) with a
`ConditioningEncoder` into a single Flax module that conforms to the
`BaseDiffusionNetwork` protocol.

* **`DiffusionNetwork`**: Single-modal model. Takes `(time, xt, conditioning)`
  and internally runs the conditioning encoder, applies any input/time
  rescaling, and calls the backbone.
* **`MultiModalDiffusionNetwork`**: Generalizes `DiffusionNetwork` to
  multi-modal PyTree data, allowing different prediction types and data dtypes
  per leaf.
* **`SelfConditioningDiffusionNetwork`**: Adds self-conditioning, where the
  model receives its own previous prediction as an additional input.

These classes also support `InputRescaler` and `TimeRescaler` for
schedule-dependent input preprocessing (e.g., EDM preconditioning).
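
A minimal composition sketch, assuming (hypothetically) that the backbone and
the conditioning encoder are plain constructor fields; see
`lib/diffusion_network.py` for the real field names:

```python
# Hypothetical sketch; field names are assumptions, not the verified API.
from hackable_diffusion.lib import diffusion_network

network = diffusion_network.DiffusionNetwork(
    backbone=backbone,                  # e.g., a Unet or the DiT above
    conditioning_encoder=cond_encoder,  # a ConditioningEncoder instance
)
# Per the protocol above, the model is then invoked as
# network(time, xt, conditioning).
```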
### `ConditionalMLP`

(`lib/architecture/mlp.py`)

docs/corruption.md

Lines changed: 5 additions & 0 deletions
@@ -367,10 +367,14 @@ The torus is a flat space with periodic boundary conditions.
### Example Usage

```python
import jax
import jax.numpy as jnp
from hackable_diffusion.lib import manifolds
from hackable_diffusion.lib.corruption.riemannian import RiemannianProcess
from hackable_diffusion.lib.corruption.schedules import LinearRiemannianSchedule

key = jax.random.PRNGKey(0)

# 1. Define manifold and process
manifold = manifolds.Sphere()
schedule = LinearRiemannianSchedule()
process = RiemannianProcess(manifold=manifold, schedule=schedule)

# 2. Corrupt data
x0 = jnp.array([[1.0, 0.0, 0.0]])  # Point on S2
time = jnp.array([0.5])
key, subkey = jax.random.split(key)
xt, target_info = process.corrupt(subkey, x0, time)

# target_info['velocity'] is the regression target u_t
```

docs/index.md

Lines changed: 9 additions & 4 deletions
@@ -80,13 +80,14 @@ encapsulates the call to the model and can be composed with other
functionalities like classifier-free guidance. This allows the main sampling
loop to be agnostic to the details of how a prediction is made.

### [Training](./training.md)

(`lib/training/`)

This module provides flexible loss functions for training diffusion models. It
includes highly configurable weighted MSE losses for Gaussian processes (like
`SiD2Loss`) and cross-entropy losses for discrete data. It also provides time
sampling strategies for selecting training timesteps.

### [Sampling](./sampling.md)

@@ -107,14 +108,18 @@ excellent starting points for understanding the library's components in action.
  a simple 2D toy dataset.
* **`mnist.ipynb`**: Trains a standard continuous diffusion model (Gaussian
  process) on the MNIST dataset, demonstrating image data handling.
* **`mnist_dit.ipynb`**: Trains a Diffusion Transformer (DiT) on MNIST,
  showcasing the DiT backbone as an alternative to U-Net.
* **`mnist_discrete.ipynb`**: Trains a discrete diffusion model on MNIST,
  treating pixel values as categorical data. This showcases the use of
  `CategoricalProcess`.
* **`mnist_simplicial.ipynb`**: Trains a simplicial diffusion model on MNIST
  using `SimplicialProcess` with Dirichlet noise on the probability simplex.
* **`mnist_multimodal.ipynb`**: A more advanced example that trains a
  multimodal model to jointly generate MNIST images with discrete and
  continuous diffusion models, demonstrating the "Nested" design pattern in a
  practical setting.
* **`mnist_nn_and_nnx.ipynb`**: Demonstrates both Flax `nn` and `nnx` module
  styles for defining diffusion networks.
* **`riemannian_sphere_training.ipynb`**: Demonstrates Riemannian Flow
  Matching on the unit sphere S^2.

docs/inference.md

Lines changed: 51 additions & 0 deletions
@@ -163,3 +163,54 @@ inference_fn = GuidedDiffusionInferenceFn(
```python
# predicted_x0 = final_prediction['x0']
# ... use predicted_x0 to compute x_t_minus_1
```

## Inference Wrappers

(`lib/inference/wrappers.py`)

In practice, you need a concrete way to convert a trained model into an
`InferenceFn`. The library provides two wrappers:

### `FlaxLinenInferenceFn`

Wraps a Flax `nn.Module` and its parameters into an `InferenceFn`. This is the
most common wrapper for models defined with the Linen API.

```python
from hackable_diffusion.lib.inference.wrappers import FlaxLinenInferenceFn

base_inference_fn = FlaxLinenInferenceFn(
    network=my_diffusion_network,  # An nn.Module
    params=restored_params,  # A pytree of model parameters
)
```

### `FlaxNNXInferenceFn`

Wraps an NNX module (converted from a Linen module) into an `InferenceFn`.

```python
from hackable_diffusion.lib.inference.wrappers import FlaxNNXInferenceFn

base_inference_fn = FlaxNNXInferenceFn(
    nnx_network=my_nnx_network,  # A ConvertedNNXDiffusionNetwork
)
```

### `convert_flax_linen_module_with_params_to_nnx`

A utility function to bridge a Linen module and its pre-trained parameters to
an NNX module:

```python
from hackable_diffusion.lib.inference.wrappers import (
    convert_flax_linen_module_with_params_to_nnx
)

nnx_model = convert_flax_linen_module_with_params_to_nnx(
    my_linen_module,  # linen_module: the trained Linen module
    restored_params,  # restored_linen_params: its parameter pytree
    dummy_time, dummy_xt, dummy_conditioning, False,  # init args
)
```
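
The returned NNX module can then be wrapped with `FlaxNNXInferenceFn`, as
shown above, to obtain an `InferenceFn` for sampling.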

docs/multimodal.md

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
# Multimodal Diffusion

This document explains how to use Hackable Diffusion's "Nested" wrappers to
build multimodal diffusion models that operate on PyTree-structured data.

The multimodal wrappers are located in `lib/multimodal.py`.

[TOC]

## Overview

Hackable Diffusion's core protocols (`CorruptionProcess`, `SamplerStep`,
`DiffusionLoss`, etc.) are designed around single-modal arrays. To handle
multimodal data — where different parts of the input (e.g., image + labels,
continuous + discrete) require different diffusion treatments — the library
provides **Nested wrappers**.

Each wrapper takes a PyTree of single-modal components that matches the
structure of your data. When called, it dispatches each method to the
corresponding component-data pair.

## Available Wrappers

### Training

* **`NestedProcess`**: Applies different corruption processes per modality.
* **`NestedDiffusionLoss`**: Computes different loss functions per modality.
* **`NestedTimeSampler`**: Samples timesteps independently per modality.

### Sampling

* **`NestedSamplerStep`**: Runs different sampler algorithms per modality.
* **`NestedTimeSchedule`**: Uses different time discretizations per modality.
* **`NestedGuidanceFn`**: Applies different guidance functions per modality.

## Key Concept: Structure Matching

The **structure of your Nested wrapper must match the structure of your data**.
For example, if your data is a dictionary `{"image": ..., "label": ...}`, your
`NestedProcess` must also be keyed by `{"image": ..., "label": ...}`.

```python
data = {
    "image": jnp.zeros((batch, 32, 32, 3)),
    "label": jnp.zeros((batch, 1), dtype=jnp.int32),
}

process = NestedProcess(
    processes={
        "image": GaussianProcess(schedule=CosineSchedule()),
        "label": CategoricalProcess.masking_process(
            schedule=LinearDiscreteSchedule(), num_categories=10,
        ),
    }
)
```
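
With the structures aligned, the nested process can be driven like its
single-modal counterpart. The call below is a hypothetical sketch: it borrows
the `corrupt` signature from the single-modal example in docs/corruption.md
and assumes the time argument is a matching PyTree (`data`, `process`, and
`batch` are from the snippet above).

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
time = {
    "image": jnp.full((batch,), 0.5),
    "label": jnp.full((batch,), 0.5),
}
# Hypothetical invocation: dispatches GaussianProcess.corrupt to "image"
# and CategoricalProcess.corrupt to "label".
xt, target_info = process.corrupt(key, data, time)
```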

## Example: Multimodal Training Setup

```python
from hackable_diffusion.lib.multimodal import (
    NestedProcess,
    NestedDiffusionLoss,
    NestedTimeSampler,
)
from hackable_diffusion.lib.corruption.gaussian import GaussianProcess
from hackable_diffusion.lib.corruption.discrete import CategoricalProcess
from hackable_diffusion.lib.corruption.schedules import (
    CosineSchedule,
    LinearDiscreteSchedule,
)
from hackable_diffusion.lib.training.gaussian_loss import SiD2Loss
from hackable_diffusion.lib.training.discrete_loss import MD4Loss
from hackable_diffusion.lib.training.time_sampling import UniformTimeSampler

# 1. Define per-modality corruption processes
process = NestedProcess(
    processes={
        "image": GaussianProcess(schedule=CosineSchedule()),
        "label": CategoricalProcess.masking_process(
            schedule=LinearDiscreteSchedule(), num_categories=10,
        ),
    }
)

# 2. Define per-modality losses
loss_fn = NestedDiffusionLoss(
    losses={
        "image": SiD2Loss(schedule=CosineSchedule()),
        "label": MD4Loss(schedule=LinearDiscreteSchedule()),
    }
)

# 3. Define per-modality time sampling (optional — can also share time)
time_sampler = NestedTimeSampler(
    time_samplers={
        "image": UniformTimeSampler(),
        "label": UniformTimeSampler(),
    }
)
```

## Example: Multimodal Sampling Setup

```python
from hackable_diffusion.lib.multimodal import (
    NestedSamplerStep,
    NestedTimeSchedule,
)
from hackable_diffusion.lib.sampling.gaussian_step_sampler import DDIMStep
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteDDIMStep
from hackable_diffusion.lib.sampling.time_scheduling import UniformTimeSchedule

sampler_step = NestedSamplerStep(
    sampler_steps={
        "image": DDIMStep(
            corruption_process=gaussian_process,
            stoch_coeff=0.0,
        ),
        "label": DiscreteDDIMStep(
            corruption_process=categorical_process,
        ),
    }
)

time_schedule = NestedTimeSchedule(
    time_schedules={
        "image": UniformTimeSchedule(),
        "label": UniformTimeSchedule(),
    }
)
```

## How It Works

Internally, Nested wrappers use `utils.lenient_map` to traverse the data and
component PyTrees in parallel, calling the corresponding method on each
component with its matching data leaf. This means:

* Any nesting depth works (dictionaries, named tuples, etc.).
* Single-modal and multimodal code share the same protocols.
* You can mix and match any combination of corruption processes, samplers, and
  losses.

The `mnist_multimodal` notebook provides a complete end-to-end example.
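
As a rough mental model only (a sketch, not the library's `lenient_map`
implementation, which also handles deeper and non-dict PyTrees), the dispatch
performed by `NestedProcess.corrupt` behaves like this:

```python
import jax

def nested_corrupt(processes, key, x0_tree, time_tree):
  """Sketch: call each process's `corrupt` on its matching data leaf."""
  # One independent RNG key per modality.
  keys = dict(zip(processes, jax.random.split(key, len(processes))))
  return {
      name: proc.corrupt(keys[name], x0_tree[name], time_tree[name])
      for name, proc in processes.items()
  }
```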

docs/sampling.md

Lines changed: 7 additions & 1 deletion
@@ -84,6 +84,9 @@ Implementations for **Gaussian** processes include:
* **`DDIMStep`**: Implements the popular Denoising Diffusion Implicit Models
  sampler. It can be deterministic (`stoch_coeff=0.0`) or stochastic
  (`stoch_coeff > 0.0`).
* **`AdjustedDDIMStep`**: An improved DDIM variant from
  <https://arxiv.org/abs/2403.06807> that adjusts the update with an
  estimated covariance term to reduce sampling error.
* **`SdeStep`**: A stochastic sampler based on discretizing the reverse-time
  Stochastic Differential Equation (SDE).
* **`VelocityStep`**: A sampler that operates using the velocity prediction
@@ -332,20 +335,23 @@ from hackable_diffusion.lib import manifolds
```python
from hackable_diffusion.lib import manifolds
from hackable_diffusion.lib.corruption.riemannian import RiemannianProcess
from hackable_diffusion.lib.corruption.schedules import LinearRiemannianSchedule
from hackable_diffusion.lib.sampling.riemannian_sampling import RiemannianFlowSamplerStep
from hackable_diffusion.lib.sampling.time_scheduling import UniformTimeSchedule

# 1. Define manifold and process
manifold = manifolds.Sphere()
process = RiemannianProcess(
    manifold=manifold,
    schedule=LinearRiemannianSchedule(),
)

# 2. Configure Sampler Step
stepper = RiemannianFlowSamplerStep(corruption_process=process)

# 3. Create the sampler
sampler = DiffusionSampler(
    time_schedule=UniformTimeSchedule(),
    stepper=stepper,
    num_steps=50,
)
```
docs/sitemap.md

Lines changed: 2 additions & 2 deletions
@@ -3,6 +3,6 @@
* [Architecture](./architecture.md)
* [Corruption Processes](./corruption.md)
* [Inference Function](./inference.md)
* [Multimodal](./multimodal.md)
* [Sampling](./sampling.md)
* [Training](./training.md)
