Skip to content

Commit bedf2d4

Browse files
vdeborto and Hackable Diffusion Authors
authored and committed
Add diagram to discrete_step_sampler.py docstring illustrating routing and planning flow.
PiperOrigin-RevId: 910755948
1 parent ae44b17 commit bedf2d4

3 files changed

Lines changed: 330 additions & 29 deletions

File tree

hackable_diffusion/lib/sampling/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from hackable_diffusion.lib.sampling.discrete_step_sampler import CorruptedMaskFn
2626
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteDDIMStep
2727
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteFlowMatchingStep
28+
from hackable_diffusion.lib.sampling.discrete_step_sampler import GreedyPlanner
2829
from hackable_diffusion.lib.sampling.discrete_step_sampler import IntegratedDiscreteDDIMStep
2930
from hackable_diffusion.lib.sampling.discrete_step_sampler import MaskValueCorruptedMaskFn
3031
from hackable_diffusion.lib.sampling.discrete_step_sampler import MaxCappedRemaskingFn

hackable_diffusion/lib/sampling/discrete_step_sampler.py

Lines changed: 189 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,75 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Actual implementation of the sampling steps.
16-
17-
This module proposes various implementations but they all have in common
18-
the core logic:
19-
20-
* An `initialize` function that takes a starting state and returns the
21-
first step of the diffusion process.
22-
* An `update` function that takes the current state and returns the next step.
23-
* A `finalize` function that takes the last state and returns the final
24-
state.
25-
26-
At every step, the update function takes the current state and returns the next
27-
state. The update is also in charge of computing other auxiliary informations
28-
such as volatility, drifts, etc.
29-
30-
The `InferenceFn` is also called within the step and converted into the
31-
relevant representation, for instance score, velocity, etc.
15+
"""Actual implementation of the sampling steps for discrete diffusion.
16+
17+
This module implements various discrete samplers (UnMasking, DDIM, Flow
18+
Matching) using a unified **routing** architecture.
19+
20+
Core concepts:
21+
* **Routing**: Samplers decompose the transition probability into a mixture of
22+
three actions:
23+
- STAY (0): Keep the current token `xt`.
24+
- NOISE (1): Sample from the invariant distribution.
25+
- CLEAN (2): Use the predicted clean token `x0`.
26+
* **RoutingWeightPlanner**: An optional transformation applied to routing
27+
weights before they are sampled. This allows injecting custom selection
28+
strategies (e.g., greedy top-k) without modifying the sampler physics.
29+
30+
Standard interface for all samplers:
31+
* `initialize`: Takes a starting state and returns the first step.
32+
* `update`: Takes current state and returns the next step.
33+
* `finalize`: Takes last state and returns the final state.
34+
35+
The following diagram illustrates the flow of the discrete denoising process:
36+
37+
┌───────────────────────┐
38+
│ DiffusionStep │
39+
│ Holds current state │
40+
│ xt │
41+
└───────────┬───────────┘
42+
43+
44+
┌───────────────────────┐
45+
│ InferenceFn │
46+
│ Calls neural network │
47+
│ to get predictions │
48+
└───────────┬───────────┘
49+
│ prediction (logits)
50+
51+
┌───────────────────────┐
52+
│ _generate_candidates │
53+
│ Extracts x0, x_noise │
54+
│ │
55+
└───────────┬───────────┘
56+
57+
58+
┌───────────────────────┐
59+
│ Compute Routing Wgts │
60+
│ Decomposes posterior │
61+
│ to STAY/NOISE/CLEAN │
62+
└───────────┬───────────┘
63+
64+
65+
┌───────────────────────┐
66+
│ Planner │
67+
│ Modifies weights │
68+
│ (Optional, e.g. topk)│
69+
└───────────┬───────────┘
70+
71+
72+
┌───────────────────────┐
73+
│ _sample_routing │
74+
│ Categorical sampling │
75+
│ to get next xt │
76+
└───────────┬───────────┘
77+
78+
79+
┌───────────────────────┐
80+
│ DiffusionStep │
81+
│ Holds next state │
82+
│ xt │
83+
└───────────────────────┘
3284
3385
This module also introduces the concepts of Routing and Planning for discrete
3486
diffusion:
@@ -381,6 +433,104 @@ def __call__(
381433
...
382434

383435

436+
@dataclasses.dataclass(frozen=True, kw_only=True)
class GreedyPlanner(RoutingStrategy):
  """Confidence-based top-k greedy planner.

  Selects the top-k positions by confidence of the sampled x0 token, forces
  those to CLEAN, and handles the remaining positions according to the
  ``resample`` flag.

  The budget k is computed as:
    k = num_eligible * p_clean_norm
  where num_eligible is the number of positions with p_clean > 0 and
  p_clean_norm is the normalized clean probability. For a linear schedule
  α = 1−t this simplifies to k = num_eligible * (1 − next_time/time).

  Attributes:
    tie_breaking_noise: Scale of uniform noise added to confidence scores for
      tie-breaking. Default 1e-6.
    resample: If True, non-selected positions keep their full original routing
      weights (stay, noise, and clean), allowing stochastic resampling via the
      posterior. If False (default), non-selected positions have their clean
      weight zeroed out so they can only stay or re-noise.
  """

  tie_breaking_noise: float = 1e-6
  resample: bool = False

  @kt.typechecked
  def __call__(
      self,
      routing_weights: RoutingWeights,
      logits: Float['... M'],
      x0: DataArray,
      xt: DataArray,
      time: TimeArray,
      next_time: TimeArray,
      key: PRNGKey,
  ) -> RoutingWeights:
    # Routing weights have shape (batch, *spatial, 1). We flatten all spatial
    # dims into a single "positions" dim so top-k selection works regardless
    # of whether the data is 1D (sequences) or 2D (adjacency matrices).
    batch_size = routing_weights.stay.shape[0]
    spatial_shape = routing_weights.stay.shape[1:-1]  # e.g. (seq,) or (N, N)

    # Confidence = softmax(logits)[x0] per position.
    p = jax.nn.softmax(logits, axis=-1)
    confidence = jnp.take_along_axis(p, x0, axis=-1).squeeze(-1)
    # confidence: (batch, *spatial) → flatten to (batch, num_positions).
    confidence = confidence.reshape(batch_size, -1)

    # Only consider positions that could go clean (p_clean > 0); ineligible
    # positions are pushed to -inf so they can never enter the top-k.
    eligible = routing_weights.clean[..., 0] > 0
    eligible = eligible.reshape(batch_size, -1)
    confidence = jnp.where(eligible, confidence, -jnp.inf)

    # Add tie-breaking noise so equal confidences are broken randomly rather
    # than by position order.
    _, subkey = jax.random.split(key)
    confidence += jax.random.uniform(subkey, confidence.shape) * (
        self.tie_breaking_noise
    )

    # Budget: k = num_eligible * p_clean_norm.
    # p_clean_norm = p_clean / (p_stay + p_noise + p_clean) is the same
    # for all eligible positions (π(x_t) cancels in normalization).
    # For a linear schedule (α = 1-t), this equals (1 - next_time/time).
    total_weight = (
        routing_weights.stay + routing_weights.noise + routing_weights.clean
    )
    p_clean_norm = routing_weights.clean / jnp.maximum(total_weight, 1e-12)
    # p_clean_norm is the same for all eligible positions. Take max over spatial
    # dimensions to get the value (ineligible positions have 0).
    p_clean_norm_flat = p_clean_norm.reshape(batch_size, -1)
    frac = jnp.max(p_clean_norm_flat, axis=-1, keepdims=True)
    frac = jnp.clip(frac, 0.0, 1.0)

    num_eligible = jnp.sum(eligible.astype(jnp.float32), axis=-1, keepdims=True)
    k = (num_eligible * frac).astype(jnp.int32)

    # Top-k threshold (operates on flattened positions).
    num_positions = confidence.shape[-1]
    sorted_conf = jnp.sort(confidence, axis=-1)[..., ::-1]
    threshold = jnp.take_along_axis(
        sorted_conf, jnp.clip(k - 1, 0, num_positions - 1), axis=-1
    )
    # When k=0, no positions should be selected.
    to_update = (confidence >= threshold) & (k > 0)

    # Unflatten to_update back to original spatial shape.
    to_update = to_update.reshape(batch_size, *spatial_shape)

    # Selected positions are forced to CLEAN (stay=0, noise=0, clean=1).
    # Non-selected positions: if ``resample`` is True they keep their full
    # original routing weights (so the posterior may stochastically route
    # them to CLEAN as well); otherwise their clean weight is zeroed out so
    # they can only stay or re-noise. Previously ``resample`` was declared
    # but never read, so the True branch was silently ignored.
    unselected_clean = (
        routing_weights.clean
        if self.resample
        else jnp.zeros_like(routing_weights.clean)
    )
    return RoutingWeights(
        stay=jnp.where(to_update[..., None], 0.0, routing_weights.stay),
        noise=jnp.where(to_update[..., None], 0.0, routing_weights.noise),
        clean=jnp.where(to_update[..., None], 1.0, unselected_clean),
    )
532+
533+
384534
################################################################################
385535
# MARK: UnMasking Step
386536
################################################################################
@@ -408,6 +558,8 @@ class UnMaskingStep(SamplerStep):
408558
409559
Attributes:
410560
corruption_process: The corruption process to use.
561+
planner: The planner to use for transforming routing weights. This is
562+
optional with the default being ``None`` (pure stochastic sampling).
411563
remasking_fn: The remasking function to use, see
412564
https://arxiv.org/abs/2503.00307v1. This is optional with the default
413565
being no remasking.
@@ -632,6 +784,12 @@ class DiscreteDDIMStep(SamplerStep):
632784
Note: when π = δ_MASK (masking process) and x_t = MASK, this reduces to:
633785
P(unmask) = (α_s - α_t) / (1 - α_t),
634786
which coincides with the UnMaskingStep formula (without remasking).
787+
788+
Attributes:
789+
corruption_process: The corruption process to use.
790+
planner: The planner to use for transforming routing weights. This is
791+
optional with the default being ``None`` (pure stochastic sampling).
792+
temperature: The temperature to use.
635793
"""
636794

637795
corruption_process: CategoricalProcess
@@ -782,15 +940,17 @@ class DiscreteFlowMatchingStep(SamplerStep):
782940
This sampler uses the 3-way routing representation. The update rule
783941
decomposes naturally into:
784942
785-
p(x_s) = p_stay * δ_{x_t} + p_up * p_x0 + p_down * π
943+
p(x_s) = p_stay * δ_{x_t} + p_clean * p_x0 + p_noise * π
786944
787945
where:
788-
- p_stay = 1 - p_up - p_down
789-
- p_up = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
790-
- p_down = (α_s - α_t) / α_t * stoch_coeff
946+
- p_stay = 1 - p_clean - p_noise
947+
- p_clean = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
948+
- p_noise = (α_s - α_t) / α_t * stoch_coeff
791949
792950
Attributes:
793951
corruption_process: The corruption process to use.
952+
planner: The planner to use for transforming routing weights. This is
953+
optional with the default being ``None`` (pure stochastic sampling).
794954
temperature: The temperature to use.
795955
stoch_coeff: The stochasticity coefficient (default 0.0). Higher values
796956
introduce more noise during the denoising process.
@@ -853,25 +1013,25 @@ def update(
8531013
alpha_s = self.corruption_process.schedule.alpha(next_time_bcast)
8541014
alpha_t = self.corruption_process.schedule.alpha(time_bcast)
8551015

856-
prob_up = (
1016+
p_clean_raw = (
8571017
(alpha_s - alpha_t)
8581018
/ jnp.maximum(1.0 - alpha_t, 1e-12)
8591019
* (1.0 + self.stoch_coeff)
8601020
)
861-
prob_down = (
1021+
p_noise_raw = (
8621022
(alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.stoch_coeff
8631023
)
8641024

865-
# Clip and rescale to ensure valid probabilities
866-
raw_p_up = jnp.maximum(prob_up, 0.0)
867-
raw_p_down = jnp.maximum(prob_down, 0.0)
868-
sum_jumps = raw_p_up + raw_p_down
1025+
# Clip and rescale to ensure valid weights
1026+
p_clean_raw = jnp.maximum(p_clean_raw, 0.0)
1027+
p_noise_raw = jnp.maximum(p_noise_raw, 0.0)
1028+
sum_jumps = p_clean_raw + p_noise_raw
8691029
scale_factor = jnp.maximum(1.0, sum_jumps)
8701030

8711031
# Compute the probabilities for the three routing options.
8721032
# This is computed according to https://arxiv.org/abs/2407.15595.
873-
p_clean = raw_p_up / scale_factor
874-
p_noise = raw_p_down / scale_factor
1033+
p_clean = p_clean_raw / scale_factor
1034+
p_noise = p_noise_raw / scale_factor
8751035
p_stay = 1.0 - p_clean - p_noise
8761036

8771037
routing_weights = RoutingWeights(stay=p_stay, noise=p_noise, clean=p_clean)

0 commit comments

Comments
 (0)