Skip to content

Commit 6bb82b8

Browse files
vdeborto and Hackable Diffusion Authors
authored and committed
Add diagram to discrete_step_sampler.py docstring illustrating routing and planning flow.
PiperOrigin-RevId: 910755948
1 parent 68eed8a commit 6bb82b8

2 files changed

Lines changed: 164 additions & 37 deletions

File tree

hackable_diffusion/lib/sampling/discrete_step_sampler.py

Lines changed: 116 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,75 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Actual implementation of the sampling steps.
16-
17-
This module proposes various implementations but they all have in common
18-
the core logic:
19-
20-
* An `initialize` function that takes a starting state and returns the
21-
first step of the diffusion process.
22-
* An `update` function that takes the current state and returns the next step.
23-
* A `finalize` function that takes the last state and returns the final
24-
state.
25-
26-
At every step, the update function takes the current state and returns the next
27-
state. The update is also in charge of computing other auxiliary informations
28-
such as volatility, drifts, etc.
29-
30-
The `InferenceFn` is also called within the step and converted into the
31-
relevant representation, for instance score, velocity, etc.
15+
"""Actual implementation of the sampling steps for discrete diffusion.
16+
17+
This module implements various discrete samplers (UnMasking, DDIM, Flow
18+
Matching) using a unified **routing** architecture.
19+
20+
Core concepts:
21+
* **Routing**: Samplers decompose the transition probability into a mixture of
22+
three actions:
23+
- STAY (0): Keep the current token `xt`.
24+
- NOISE (1): Sample from the invariant distribution.
25+
- CLEAN (2): Use the predicted clean token `x0`.
26+
* **RoutingWeightPlanner**: An optional transformation applied to routing
27+
weights before they are sampled. This allows injecting custom selection
28+
strategies (e.g., greedy top-k) without modifying the sampler physics.
29+
30+
Standard interface for all samplers:
31+
* `initialize`: Takes a starting state and returns the first step.
32+
* `update`: Takes current state and returns the next step.
33+
* `finalize`: Takes last state and returns the final state.
34+
35+
The following diagram illustrates the flow of the discrete denoising process:
36+
37+
┌───────────────────────┐
38+
│ DiffusionStep │
39+
│ Holds current state │
40+
│ xt │
41+
└───────────┬───────────┘
42+
43+
44+
┌───────────────────────┐
45+
│ InferenceFn │
46+
│ Calls neural network │
47+
│ to get predictions │
48+
└───────────┬───────────┘
49+
│ prediction (logits)
50+
51+
┌───────────────────────┐
52+
│ _generate_candidates │
53+
│ Extracts x0, x_noise │
54+
│ │
55+
└───────────┬───────────┘
56+
57+
58+
┌───────────────────────┐
59+
│ Compute Routing Wgts │
60+
│ Decomposes posterior │
61+
│ to STAY/NOISE/CLEAN │
62+
└───────────┬───────────┘
63+
64+
65+
┌───────────────────────┐
66+
│ Planner │
67+
│ Modifies weights │
68+
│ (Optional, e.g. topk)│
69+
└───────────┬───────────┘
70+
71+
72+
┌───────────────────────┐
73+
│ _sample_routing │
74+
│ Categorical sampling │
75+
│ to get next xt │
76+
└───────────┬───────────┘
77+
78+
79+
┌───────────────────────┐
80+
│ DiffusionStep │
81+
│ Holds next state │
82+
│ xt │
83+
└───────────────────────┘
3284
3385
This module also introduces the concepts of Routing and Planning for discrete
3486
diffusion:
@@ -386,8 +438,8 @@ class GreedyPlanner(RoutingStrategy):
386438
"""Confidence-based top-k greedy planner.
387439
388440
Selects the top-k positions by confidence of the sampled x0 token, forces
389-
those to CLEAN, and lets the remaining positions follow their natural
390-
stay/noise dynamics (with p_clean zeroed out).
441+
those to CLEAN, and handles the remaining positions according to the
442+
``resample`` flag.
391443
392444
The budget k is computed as:
393445
k = num_eligible * p_clean_norm
@@ -398,9 +450,14 @@ class GreedyPlanner(RoutingStrategy):
398450
Attributes:
399451
tie_breaking_noise: Scale of uniform noise added to confidence scores for
400452
tie-breaking. Default 1e-6.
453+
resample: If True, non-selected positions keep their full original routing
454+
weights (stay, noise, and clean), allowing stochastic resampling via the
455+
posterior. If False (default), non-selected positions have their clean
456+
weight zeroed out so they can only stay or re-noise.
401457
"""
402458

403459
tie_breaking_noise: float = 1e-6
460+
resample: bool = False
404461

405462
@kt.typechecked
406463
def __call__(
@@ -413,12 +470,21 @@ def __call__(
413470
next_time: TimeArray,
414471
key: PRNGKey,
415472
) -> RoutingWeights:
473+
# Routing weights have shape (batch, *spatial, 1). We flatten all spatial
474+
# dims into a single "positions" dim so top-k selection works regardless
475+
# of whether the data is 1D (sequences) or 2D (adjacency matrices).
476+
batch_size = routing_weights.stay.shape[0]
477+
spatial_shape = routing_weights.stay.shape[1:-1] # e.g. (seq,) or (N, N)
478+
416479
# Confidence = softmax(logits)[x0] per position
417480
p = jax.nn.softmax(logits, axis=-1)
418481
confidence = jnp.take_along_axis(p, x0, axis=-1).squeeze(-1)
482+
# confidence: (batch, *spatial) → flatten to (batch, num_positions)
483+
confidence = confidence.reshape(batch_size, -1)
419484

420485
# Only consider positions that could go clean (p_clean > 0)
421486
eligible = routing_weights.clean[..., 0] > 0
487+
eligible = eligible.reshape(batch_size, -1)
422488
confidence = jnp.where(eligible, confidence, -jnp.inf)
423489

424490
# Add tie-breaking noise
@@ -437,24 +503,27 @@ def __call__(
437503
p_clean_norm = routing_weights.clean / jnp.maximum(total_weight, 1e-12)
438504
# p_clean_norm is the same for all eligible positions. Take max over spatial
439505
# dimensions to get the value (ineligible positions have 0).
440-
p_clean_norm_flat = p_clean_norm.reshape(p_clean_norm.shape[0], -1)
506+
p_clean_norm_flat = p_clean_norm.reshape(batch_size, -1)
441507
frac = jnp.max(p_clean_norm_flat, axis=-1, keepdims=True)
442508
frac = jnp.clip(frac, 0.0, 1.0)
443509

444510
num_eligible = jnp.sum(eligible.astype(jnp.float32), axis=-1, keepdims=True)
445511
k = (num_eligible * frac).astype(jnp.int32)
446512

447-
# Top-k threshold
448-
seq_len = confidence.shape[-1]
513+
# Top-k threshold (operates on flattened positions)
514+
num_positions = confidence.shape[-1]
449515
sorted_conf = jnp.sort(confidence, axis=-1)[..., ::-1]
450516
threshold = jnp.take_along_axis(
451-
sorted_conf, jnp.clip(k - 1, 0, seq_len - 1), axis=-1
517+
sorted_conf, jnp.clip(k - 1, 0, num_positions - 1), axis=-1
452518
)
453519
# When k=0, no positions should be selected.
454520
to_update = (confidence >= threshold) & (k > 0)
455521

456-
# Selected positions → force CLEAN (zero out stay/noise).
457-
# Non-selected positions → zero out p_clean, keep original stay/noise.
522+
# Unflatten to_update back to original spatial shape
523+
to_update = to_update.reshape(batch_size, *spatial_shape)
524+
525+
# Non-selected positions have clean zeroed out; they can only stay or
526+
# re-noise according to the posterior.
458527
return RoutingWeights(
459528
stay=jnp.where(to_update[..., None], 0.0, routing_weights.stay),
460529
noise=jnp.where(to_update[..., None], 0.0, routing_weights.noise),
@@ -489,6 +558,8 @@ class UnMaskingStep(SamplerStep):
489558
490559
Attributes:
491560
corruption_process: The corruption process to use.
561+
planner: The planner to use for transforming routing weights. This is
562+
optional with the default being ``None`` (pure stochastic sampling).
492563
remasking_fn: The remasking function to use, see
493564
https://arxiv.org/abs/2503.00307v1. This is optional with the default
494565
being no remasking.
@@ -713,6 +784,12 @@ class DiscreteDDIMStep(SamplerStep):
713784
Note: when π = δ_MASK (masking process) and x_t = MASK, this reduces to:
714785
P(unmask) = (α_s - α_t) / (1 - α_t),
715786
which coincides with the UnMaskingStep formula (without remasking).
787+
788+
Attributes:
789+
corruption_process: The corruption process to use.
790+
planner: The planner to use for transforming routing weights. This is
791+
optional with the default being ``None`` (pure stochastic sampling).
792+
temperature: The temperature to use.
716793
"""
717794

718795
corruption_process: CategoricalProcess
@@ -863,15 +940,17 @@ class DiscreteFlowMatchingStep(SamplerStep):
863940
This sampler uses the 3-way routing representation. The update rule
864941
decomposes naturally into:
865942
866-
p(x_s) = p_stay * δ_{x_t} + p_up * p_x0 + p_down * π
943+
p(x_s) = p_stay * δ_{x_t} + p_clean * p_x0 + p_noise * π
867944
868945
where:
869-
- p_stay = 1 - p_up - p_down
870-
- p_up = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
871-
- p_down = (α_s - α_t) / α_t * stoch_coeff
946+
- p_stay = 1 - p_clean - p_noise
947+
- p_clean = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
948+
- p_noise = (α_s - α_t) / α_t * stoch_coeff
872949
873950
Attributes:
874951
corruption_process: The corruption process to use.
952+
planner: The planner to use for transforming routing weights. This is
953+
optional with the default being ``None`` (pure stochastic sampling).
875954
temperature: The temperature to use.
876955
stoch_coeff: The stochasticity coefficient (default 0.0). Higher values
877956
introduce more noise during the denoising process.
@@ -934,25 +1013,25 @@ def update(
9341013
alpha_s = self.corruption_process.schedule.alpha(next_time_bcast)
9351014
alpha_t = self.corruption_process.schedule.alpha(time_bcast)
9361015

937-
prob_up = (
1016+
p_clean_raw = (
9381017
(alpha_s - alpha_t)
9391018
/ jnp.maximum(1.0 - alpha_t, 1e-12)
9401019
* (1.0 + self.stoch_coeff)
9411020
)
942-
prob_down = (
1021+
p_noise_raw = (
9431022
(alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.stoch_coeff
9441023
)
9451024

946-
# Clip and rescale to ensure valid probabilities
947-
raw_p_up = jnp.maximum(prob_up, 0.0)
948-
raw_p_down = jnp.maximum(prob_down, 0.0)
949-
sum_jumps = raw_p_up + raw_p_down
1025+
# Clip and rescale to ensure valid weights
1026+
p_clean_raw = jnp.maximum(p_clean_raw, 0.0)
1027+
p_noise_raw = jnp.maximum(p_noise_raw, 0.0)
1028+
sum_jumps = p_clean_raw + p_noise_raw
9501029
scale_factor = jnp.maximum(1.0, sum_jumps)
9511030

9521031
# Compute the probabilities for the three routing options.
9531032
# This is computed according to https://arxiv.org/abs/2407.15595.
954-
p_clean = raw_p_up / scale_factor
955-
p_noise = raw_p_down / scale_factor
1033+
p_clean = p_clean_raw / scale_factor
1034+
p_noise = p_noise_raw / scale_factor
9561035
p_stay = 1.0 - p_clean - p_noise
9571036

9581037
routing_weights = RoutingWeights(stay=p_stay, noise=p_noise, clean=p_clean)

hackable_diffusion/lib/sampling/discrete_step_sampler_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,54 @@ def test_greedy_planner_k_zero(self):
10541054
chex.assert_trees_all_close(out_probs.noise, expected.noise)
10551055
chex.assert_trees_all_close(out_probs.clean, expected.clean)
10561056

1057+
def test_greedy_planner_2d_spatial(self):
1058+
"""GreedyPlanner must work with 2D spatial data (e.g.
1059+
1060+
adjacency matrices).
1061+
"""
1062+
planner = discrete_step_sampler.GreedyPlanner()
1063+
# Shape (1, 3, 3, 1): batch=1, spatial=(3,3), vocab_trailing=1.
1064+
routing_weights = discrete_step_sampler.RoutingWeights(
1065+
stay=jnp.full((1, 3, 3, 1), 0.3),
1066+
noise=jnp.full((1, 3, 3, 1), 0.2),
1067+
clean=jnp.full((1, 3, 3, 1), 0.5),
1068+
)
1069+
# 9 positions total, 2 vocab classes.
1070+
# Logits: position (0,0) has highest confidence, then (0,1), etc.
1071+
logits_flat = jnp.array([
1072+
[10.0, 0.0],
1073+
[9.0, 0.0],
1074+
[8.0, 0.0],
1075+
[7.0, 0.0],
1076+
[6.0, 0.0],
1077+
[5.0, 0.0],
1078+
[4.0, 0.0],
1079+
[3.0, 0.0],
1080+
[2.0, 0.0],
1081+
])
1082+
logits = logits_flat.reshape(1, 3, 3, 2)
1083+
x0 = jnp.zeros((1, 3, 3, 1), dtype=jnp.int32)
1084+
xt = jnp.ones((1, 3, 3, 1), dtype=jnp.int32)
1085+
time = jnp.array([1.0])
1086+
next_time = jnp.array([0.5]) # frac = 0.5 -> budget = 9 * 0.5 = 4
1087+
key = jax.random.PRNGKey(0)
1088+
1089+
out = planner(routing_weights, logits, x0, xt, time, next_time, key)
1090+
1091+
# Output must have the same spatial shape.
1092+
self.assertEqual(out.stay.shape, (1, 3, 3, 1))
1093+
self.assertEqual(out.noise.shape, (1, 3, 3, 1))
1094+
self.assertEqual(out.clean.shape, (1, 3, 3, 1))
1095+
1096+
# Budget k = 4. Top-4 positions (by confidence) → forced CLEAN.
1097+
# Positions (0,0), (0,1), (0,2), (1,0) should be selected.
1098+
selected = out.clean[0, :, :, 0] # (3, 3)
1099+
num_selected = int(jnp.sum(selected > 0))
1100+
self.assertEqual(num_selected, 4)
1101+
# Selected positions have stay=0, noise=0.
1102+
self.assertEqual(float(jnp.sum(out.stay[out.clean > 0])), 0.0)
1103+
self.assertEqual(float(jnp.sum(out.noise[out.clean > 0])), 0.0)
1104+
10571105

10581106
if __name__ == '__main__':
10591107
absltest.main()

0 commit comments

Comments
 (0)