diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py index fa33f89..0c2de15 100644 --- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py +++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py @@ -32,6 +32,56 @@ * `update`: Takes current state and returns the next step. * `finalize`: Takes last state and returns the final state. + The following diagram illustrates the flow of the discrete denoising process: + + ┌───────────────────────┐ + │ DiffusionStep │ + │ Holds current state │ + │ xt │ + └───────────┬───────────┘ + │ + ▼ + ┌───────────────────────┐ + │ InferenceFn │ + │ Calls neural network │ + │ to get predictions │ + └───────────┬───────────┘ + │ prediction (logits) + ▼ + ┌───────────────────────┐ + │ _generate_candidates │ + │ Extracts x0, x_noise │ + │ │ + └───────────┬───────────┘ + │ + ▼ + ┌───────────────────────┐ + │ Compute Routing Wgts │ + │ Decomposes posterior │ + │ to STAY/NOISE/CLEAN │ + └───────────┬───────────┘ + │ + ▼ + ┌───────────────────────┐ + │ Planner │ + │ Modifies weights │ + │ (Optional, e.g. topk)│ + └───────────┬───────────┘ + │ + ▼ + ┌───────────────────────┐ + │ _sample_routing │ + │ Categorical sampling │ + │ to get next xt │ + └───────────┬───────────┘ + │ + ▼ + ┌───────────────────────┐ + │ DiffusionStep │ + │ Holds next state │ + │ xt │ + └───────────────────────┘ + This module also introduces the concepts of Routing and Planning for discrete diffusion: - Routing: Defines the transition probabilities at each position among the