Skip to content

Commit acfc276

Browse files
vdeborto and Hackable Diffusion Authors
authored and committed
Update documentation for routing and planner architecture.
PiperOrigin-RevId: 912015141
1 parent 68eed8a commit acfc276

2 files changed

Lines changed: 114 additions & 37 deletions

File tree

hackable_diffusion/lib/sampling/discrete_step_sampler.py

Lines changed: 66 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -12,23 +12,25 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Actual implementation of the sampling steps.
16-
17-
This module proposes various implementations but they all have in common
18-
the core logic:
19-
20-
* An `initialize` function that takes a starting state and returns the
21-
first step of the diffusion process.
22-
* An `update` function that takes the current state and returns the next step.
23-
* A `finalize` function that takes the last state and returns the final
24-
state.
25-
26-
At every step, the update function takes the current state and returns the next
27-
state. The update is also in charge of computing other auxiliary informations
28-
such as volatility, drifts, etc.
29-
30-
The `InferenceFn` is also called within the step and converted into the
31-
relevant representation, for instance score, velocity, etc.
15+
"""Actual implementation of the sampling steps for discrete diffusion.
16+
17+
This module implements various discrete samplers (UnMasking, DDIM, Flow
18+
Matching) using a unified **routing** architecture.
19+
20+
Core concepts:
21+
* **Routing**: Samplers decompose the transition probability into a mixture of
22+
three actions:
23+
- STAY (0): Keep the current token `xt`.
24+
- NOISE (1): Sample from the invariant distribution.
25+
- CLEAN (2): Use the predicted clean token `x0`.
26+
* **RoutingWeightPlanner**: An optional transformation applied to routing
27+
weights before they are sampled. This allows injecting custom selection
28+
strategies (e.g., greedy top-k) without modifying the sampler physics.
29+
30+
Standard interface for all samplers:
31+
* `initialize`: Takes a starting state and returns the first step.
32+
* `update`: Takes current state and returns the next step.
33+
* `finalize`: Takes last state and returns the final state.
3234
3335
This module also introduces the concepts of Routing and Planning for discrete
3436
diffusion:
@@ -386,8 +388,8 @@ class GreedyPlanner(RoutingStrategy):
386388
"""Confidence-based top-k greedy planner.
387389
388390
Selects the top-k positions by confidence of the sampled x0 token, forces
389-
those to CLEAN, and lets the remaining positions follow their natural
390-
stay/noise dynamics (with p_clean zeroed out).
391+
those to CLEAN, and handles the remaining positions according to the
392+
``resample`` flag.
391393
392394
The budget k is computed as:
393395
k = num_eligible * p_clean_norm
@@ -398,9 +400,14 @@ class GreedyPlanner(RoutingStrategy):
398400
Attributes:
399401
tie_breaking_noise: Scale of uniform noise added to confidence scores for
400402
tie-breaking. Default 1e-6.
403+
resample: If True, non-selected positions keep their full original routing
404+
weights (stay, noise, and clean), allowing stochastic resampling via the
405+
posterior. If False (default), non-selected positions have their clean
406+
weight zeroed out so they can only stay or re-noise.
401407
"""
402408

403409
tie_breaking_noise: float = 1e-6
410+
resample: bool = False
404411

405412
@kt.typechecked
406413
def __call__(
@@ -413,12 +420,21 @@ def __call__(
413420
next_time: TimeArray,
414421
key: PRNGKey,
415422
) -> RoutingWeights:
423+
# Routing weights have shape (batch, *spatial, 1). We flatten all spatial
424+
# dims into a single "positions" dim so top-k selection works regardless
425+
# of whether the data is 1D (sequences) or 2D (adjacency matrices).
426+
batch_size = routing_weights.stay.shape[0]
427+
spatial_shape = routing_weights.stay.shape[1:-1] # e.g. (seq,) or (N, N)
428+
416429
# Confidence = softmax(logits)[x0] per position
417430
p = jax.nn.softmax(logits, axis=-1)
418431
confidence = jnp.take_along_axis(p, x0, axis=-1).squeeze(-1)
432+
# confidence: (batch, *spatial) → flatten to (batch, num_positions)
433+
confidence = confidence.reshape(batch_size, -1)
419434

420435
# Only consider positions that could go clean (p_clean > 0)
421436
eligible = routing_weights.clean[..., 0] > 0
437+
eligible = eligible.reshape(batch_size, -1)
422438
confidence = jnp.where(eligible, confidence, -jnp.inf)
423439

424440
# Add tie-breaking noise
@@ -437,24 +453,27 @@ def __call__(
437453
p_clean_norm = routing_weights.clean / jnp.maximum(total_weight, 1e-12)
438454
# p_clean_norm is the same for all eligible positions. Take max over spatial
439455
# dimensions to get the value (ineligible positions have 0).
440-
p_clean_norm_flat = p_clean_norm.reshape(p_clean_norm.shape[0], -1)
456+
p_clean_norm_flat = p_clean_norm.reshape(batch_size, -1)
441457
frac = jnp.max(p_clean_norm_flat, axis=-1, keepdims=True)
442458
frac = jnp.clip(frac, 0.0, 1.0)
443459

444460
num_eligible = jnp.sum(eligible.astype(jnp.float32), axis=-1, keepdims=True)
445461
k = (num_eligible * frac).astype(jnp.int32)
446462

447-
# Top-k threshold
448-
seq_len = confidence.shape[-1]
463+
# Top-k threshold (operates on flattened positions)
464+
num_positions = confidence.shape[-1]
449465
sorted_conf = jnp.sort(confidence, axis=-1)[..., ::-1]
450466
threshold = jnp.take_along_axis(
451-
sorted_conf, jnp.clip(k - 1, 0, seq_len - 1), axis=-1
467+
sorted_conf, jnp.clip(k - 1, 0, num_positions - 1), axis=-1
452468
)
453469
# When k=0, no positions should be selected.
454470
to_update = (confidence >= threshold) & (k > 0)
455471

456-
# Selected positions → force CLEAN (zero out stay/noise).
457-
# Non-selected positions → zero out p_clean, keep original stay/noise.
472+
# Unflatten to_update back to original spatial shape
473+
to_update = to_update.reshape(batch_size, *spatial_shape)
474+
475+
# Non-selected positions have clean zeroed out; they can only stay or
476+
# re-noise according to the posterior.
458477
return RoutingWeights(
459478
stay=jnp.where(to_update[..., None], 0.0, routing_weights.stay),
460479
noise=jnp.where(to_update[..., None], 0.0, routing_weights.noise),
@@ -489,6 +508,8 @@ class UnMaskingStep(SamplerStep):
489508
490509
Attributes:
491510
corruption_process: The corruption process to use.
511+
planner: The planner to use for transforming routing weights. This is
512+
optional with the default being ``None`` (pure stochastic sampling).
492513
remasking_fn: The remasking function to use, see
493514
https://arxiv.org/abs/2503.00307v1. This is optional with the default
494515
being no remasking.
@@ -713,6 +734,12 @@ class DiscreteDDIMStep(SamplerStep):
713734
Note: when π = δ_MASK (masking process) and x_t = MASK, this reduces to:
714735
P(unmask) = (α_s - α_t) / (1 - α_t),
715736
which coincides with the UnMaskingStep formula (without remasking).
737+
738+
Attributes:
739+
corruption_process: The corruption process to use.
740+
planner: The planner to use for transforming routing weights. This is
741+
optional with the default being ``None`` (pure stochastic sampling).
742+
temperature: The temperature to use.
716743
"""
717744

718745
corruption_process: CategoricalProcess
@@ -863,15 +890,17 @@ class DiscreteFlowMatchingStep(SamplerStep):
863890
This sampler uses the 3-way routing representation. The update rule
864891
decomposes naturally into:
865892
866-
p(x_s) = p_stay * δ_{x_t} + p_up * p_x0 + p_down * π
893+
p(x_s) = p_stay * δ_{x_t} + p_clean * p_x0 + p_noise * π
867894
868895
where:
869-
- p_stay = 1 - p_up - p_down
870-
- p_up = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
871-
- p_down = (α_s - α_t) / α_t * stoch_coeff
896+
- p_stay = 1 - p_clean - p_noise
897+
- p_clean = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
898+
- p_noise = (α_s - α_t) / α_t * stoch_coeff
872899
873900
Attributes:
874901
corruption_process: The corruption process to use.
902+
planner: The planner to use for transforming routing weights. This is
903+
optional with the default being ``None`` (pure stochastic sampling).
875904
temperature: The temperature to use.
876905
stoch_coeff: The stochasticity coefficient (default 0.0). Higher values
877906
introduce more noise during the denoising process.
@@ -934,25 +963,25 @@ def update(
934963
alpha_s = self.corruption_process.schedule.alpha(next_time_bcast)
935964
alpha_t = self.corruption_process.schedule.alpha(time_bcast)
936965

937-
prob_up = (
966+
p_clean_raw = (
938967
(alpha_s - alpha_t)
939968
/ jnp.maximum(1.0 - alpha_t, 1e-12)
940969
* (1.0 + self.stoch_coeff)
941970
)
942-
prob_down = (
971+
p_noise_raw = (
943972
(alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.stoch_coeff
944973
)
945974

946-
# Clip and rescale to ensure valid probabilities
947-
raw_p_up = jnp.maximum(prob_up, 0.0)
948-
raw_p_down = jnp.maximum(prob_down, 0.0)
949-
sum_jumps = raw_p_up + raw_p_down
975+
# Clip and rescale to ensure valid weights
976+
p_clean_raw = jnp.maximum(p_clean_raw, 0.0)
977+
p_noise_raw = jnp.maximum(p_noise_raw, 0.0)
978+
sum_jumps = p_clean_raw + p_noise_raw
950979
scale_factor = jnp.maximum(1.0, sum_jumps)
951980

952981
# Compute the probabilities for the three routing options.
953982
# This is computed according to https://arxiv.org/abs/2407.15595.
954-
p_clean = raw_p_up / scale_factor
955-
p_noise = raw_p_down / scale_factor
983+
p_clean = p_clean_raw / scale_factor
984+
p_noise = p_noise_raw / scale_factor
956985
p_stay = 1.0 - p_clean - p_noise
957986

958987
routing_weights = RoutingWeights(stay=p_stay, noise=p_noise, clean=p_clean)

hackable_diffusion/lib/sampling/discrete_step_sampler_test.py

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1054,6 +1054,54 @@ def test_greedy_planner_k_zero(self):
10541054
chex.assert_trees_all_close(out_probs.noise, expected.noise)
10551055
chex.assert_trees_all_close(out_probs.clean, expected.clean)
10561056

1057+
def test_greedy_planner_2d_spatial(self):
1058+
"""GreedyPlanner must work with 2D spatial data (e.g.
1059+
1060+
adjacency matrices).
1061+
"""
1062+
planner = discrete_step_sampler.GreedyPlanner()
1063+
# Shape (1, 3, 3, 1): batch=1, spatial=(3,3), vocab_trailing=1.
1064+
routing_weights = discrete_step_sampler.RoutingWeights(
1065+
stay=jnp.full((1, 3, 3, 1), 0.3),
1066+
noise=jnp.full((1, 3, 3, 1), 0.2),
1067+
clean=jnp.full((1, 3, 3, 1), 0.5),
1068+
)
1069+
# 9 positions total, 2 vocab classes.
1070+
# Logits: position (0,0) has highest confidence, then (0,1), etc.
1071+
logits_flat = jnp.array([
1072+
[10.0, 0.0],
1073+
[9.0, 0.0],
1074+
[8.0, 0.0],
1075+
[7.0, 0.0],
1076+
[6.0, 0.0],
1077+
[5.0, 0.0],
1078+
[4.0, 0.0],
1079+
[3.0, 0.0],
1080+
[2.0, 0.0],
1081+
])
1082+
logits = logits_flat.reshape(1, 3, 3, 2)
1083+
x0 = jnp.zeros((1, 3, 3, 1), dtype=jnp.int32)
1084+
xt = jnp.ones((1, 3, 3, 1), dtype=jnp.int32)
1085+
time = jnp.array([1.0])
1086+
next_time = jnp.array([0.5]) # frac = 0.5 -> budget = 9 * 0.5 = 4
1087+
key = jax.random.PRNGKey(0)
1088+
1089+
out = planner(routing_weights, logits, x0, xt, time, next_time, key)
1090+
1091+
# Output must have the same spatial shape.
1092+
self.assertEqual(out.stay.shape, (1, 3, 3, 1))
1093+
self.assertEqual(out.noise.shape, (1, 3, 3, 1))
1094+
self.assertEqual(out.clean.shape, (1, 3, 3, 1))
1095+
1096+
# Budget k = 4. Top-4 positions (by confidence) → forced CLEAN.
1097+
# Positions (0,0), (0,1), (0,2), (1,0) should be selected.
1098+
selected = out.clean[0, :, :, 0] # (3, 3)
1099+
num_selected = int(jnp.sum(selected > 0))
1100+
self.assertEqual(num_selected, 4)
1101+
# Selected positions have stay=0, noise=0.
1102+
self.assertEqual(float(jnp.sum(out.stay[out.clean > 0])), 0.0)
1103+
self.assertEqual(float(jnp.sum(out.noise[out.clean > 0])), 0.0)
1104+
10571105

10581106
if __name__ == '__main__':
10591107
absltest.main()

0 commit comments

Comments
 (0)