Skip to content

Commit a3fa40c

Browse files
vdeborto and Hackable Diffusion Authors
authored and committed
Update documentation for routing and planner architecture.
PiperOrigin-RevId: 908000167
1 parent ae44b17 commit a3fa40c

3 files changed

Lines changed: 280 additions & 29 deletions

File tree

hackable_diffusion/lib/sampling/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from hackable_diffusion.lib.sampling.discrete_step_sampler import CorruptedMaskFn
2626
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteDDIMStep
2727
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteFlowMatchingStep
28+
from hackable_diffusion.lib.sampling.discrete_step_sampler import GreedyPlanner
2829
from hackable_diffusion.lib.sampling.discrete_step_sampler import IntegratedDiscreteDDIMStep
2930
from hackable_diffusion.lib.sampling.discrete_step_sampler import MaskValueCorruptedMaskFn
3031
from hackable_diffusion.lib.sampling.discrete_step_sampler import MaxCappedRemaskingFn

hackable_diffusion/lib/sampling/discrete_step_sampler.py

Lines changed: 139 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,25 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Actual implementation of the sampling steps.
16-
17-
This module proposes various implementations but they all have in common
18-
the core logic:
19-
20-
* An `initialize` function that takes a starting state and returns the
21-
first step of the diffusion process.
22-
* An `update` function that takes the current state and returns the next step.
23-
* A `finalize` function that takes the last state and returns the final
24-
state.
25-
26-
At every step, the update function takes the current state and returns the next
27-
state. The update is also in charge of computing other auxiliary informations
28-
such as volatility, drifts, etc.
29-
30-
The `InferenceFn` is also called within the step and converted into the
31-
relevant representation, for instance score, velocity, etc.
15+
"""Actual implementation of the sampling steps for discrete diffusion.
16+
17+
This module implements various discrete samplers (UnMasking, DDIM, Flow
18+
Matching) using a unified **routing** architecture.
19+
20+
Core concepts:
21+
* **Routing**: Samplers decompose the transition probability into a mixture of
22+
three actions:
23+
- STAY (0): Keep the current token `xt`.
24+
- NOISE (1): Sample from the invariant distribution.
25+
- CLEAN (2): Use the predicted clean token `x0`.
26+
* **RoutingWeightPlanner**: An optional transformation applied to routing
27+
weights before they are sampled. This allows injecting custom selection
28+
strategies (e.g., greedy top-k) without modifying the sampler physics.
29+
30+
Standard interface for all samplers:
31+
* `initialize`: Takes a starting state and returns the first step.
32+
* `update`: Takes current state and returns the next step.
33+
* `finalize`: Takes last state and returns the final state.
3234
3335
This module also introduces the concepts of Routing and Planning for discrete
3436
diffusion:
@@ -381,6 +383,104 @@ def __call__(
381383
...
382384

383385

386+
@dataclasses.dataclass(frozen=True, kw_only=True)
387+
class GreedyPlanner(RoutingStrategy):
388+
"""Confidence-based top-k greedy planner.
389+
390+
Selects the top-k positions by confidence of the sampled x0 token, forces
391+
those to CLEAN, and handles the remaining positions according to the
392+
``resample`` flag.
393+
394+
The budget k is computed as:
395+
k = num_eligible * p_clean_norm
396+
where num_eligible is the number of positions with p_clean > 0 and
397+
p_clean_norm is the normalized clean probability. For a linear schedule
398+
α = 1−t this simplifies to k = num_eligible * (1 − next_time/time).
399+
400+
Attributes:
401+
tie_breaking_noise: Scale of uniform noise added to confidence scores for
402+
tie-breaking. Default 1e-6.
403+
resample: If True, non-selected positions keep their full original routing
404+
weights (stay, noise, and clean), allowing stochastic resampling via the
405+
posterior. If False (default), non-selected positions have their clean
406+
weight zeroed out so they can only stay or re-noise.
407+
"""
408+
409+
tie_breaking_noise: float = 1e-6
410+
resample: bool = False
411+
412+
@kt.typechecked
413+
def __call__(
414+
self,
415+
routing_weights: RoutingWeights,
416+
logits: Float['... M'],
417+
x0: DataArray,
418+
xt: DataArray,
419+
time: TimeArray,
420+
next_time: TimeArray,
421+
key: PRNGKey,
422+
) -> RoutingWeights:
423+
# Routing weights have shape (batch, *spatial, 1). We flatten all spatial
424+
# dims into a single "positions" dim so top-k selection works regardless
425+
# of whether the data is 1D (sequences) or 2D (adjacency matrices).
426+
batch_size = routing_weights.stay.shape[0]
427+
spatial_shape = routing_weights.stay.shape[1:-1] # e.g. (seq,) or (N, N)
428+
429+
# Confidence = softmax(logits)[x0] per position
430+
p = jax.nn.softmax(logits, axis=-1)
431+
confidence = jnp.take_along_axis(p, x0, axis=-1).squeeze(-1)
432+
# confidence: (batch, *spatial) → flatten to (batch, num_positions)
433+
confidence = confidence.reshape(batch_size, -1)
434+
435+
# Only consider positions that could go clean (p_clean > 0)
436+
eligible = routing_weights.clean[..., 0] > 0
437+
eligible = eligible.reshape(batch_size, -1)
438+
confidence = jnp.where(eligible, confidence, -jnp.inf)
439+
440+
# Add tie-breaking noise
441+
_, subkey = jax.random.split(key)
442+
confidence += jax.random.uniform(subkey, confidence.shape) * (
443+
self.tie_breaking_noise
444+
)
445+
446+
# Budget: k = num_eligible * p_clean_norm.
447+
# p_clean_norm = p_clean / (p_stay + p_noise + p_clean) is the same
448+
# for all eligible positions (π(x_t) cancels in normalization).
449+
# For a linear schedule (α = 1-t), this equals (1 - next_time/time).
450+
total_weight = (
451+
routing_weights.stay + routing_weights.noise + routing_weights.clean
452+
)
453+
p_clean_norm = routing_weights.clean / jnp.maximum(total_weight, 1e-12)
454+
# p_clean_norm is the same for all eligible positions. Take max over spatial
455+
# dimensions to get the value (ineligible positions have 0).
456+
p_clean_norm_flat = p_clean_norm.reshape(batch_size, -1)
457+
frac = jnp.max(p_clean_norm_flat, axis=-1, keepdims=True)
458+
frac = jnp.clip(frac, 0.0, 1.0)
459+
460+
num_eligible = jnp.sum(eligible.astype(jnp.float32), axis=-1, keepdims=True)
461+
k = (num_eligible * frac).astype(jnp.int32)
462+
463+
# Top-k threshold (operates on flattened positions)
464+
num_positions = confidence.shape[-1]
465+
sorted_conf = jnp.sort(confidence, axis=-1)[..., ::-1]
466+
threshold = jnp.take_along_axis(
467+
sorted_conf, jnp.clip(k - 1, 0, num_positions - 1), axis=-1
468+
)
469+
# When k=0, no positions should be selected.
470+
to_update = (confidence >= threshold) & (k > 0)
471+
472+
# Unflatten to_update back to original spatial shape
473+
to_update = to_update.reshape(batch_size, *spatial_shape)
474+
475+
# Non-selected positions have clean zeroed out; they can only stay or
476+
# re-noise according to the posterior.
477+
return RoutingWeights(
478+
stay=jnp.where(to_update[..., None], 0.0, routing_weights.stay),
479+
noise=jnp.where(to_update[..., None], 0.0, routing_weights.noise),
480+
clean=jnp.where(to_update[..., None], 1.0, 0.0),
481+
)
482+
483+
384484
################################################################################
385485
# MARK: UnMasking Step
386486
################################################################################
@@ -408,6 +508,8 @@ class UnMaskingStep(SamplerStep):
408508
409509
Attributes:
410510
corruption_process: The corruption process to use.
511+
planner: The planner to use for transforming routing weights. This is
512+
optional with the default being ``None`` (pure stochastic sampling).
411513
remasking_fn: The remasking function to use, see
412514
https://arxiv.org/abs/2503.00307v1. This is optional with the default
413515
being no remasking.
@@ -632,6 +734,12 @@ class DiscreteDDIMStep(SamplerStep):
632734
Note: when π = δ_MASK (masking process) and x_t = MASK, this reduces to:
633735
P(unmask) = (α_s - α_t) / (1 - α_t),
634736
which coincides with the UnMaskingStep formula (without remasking).
737+
738+
Attributes:
739+
corruption_process: The corruption process to use.
740+
planner: The planner to use for transforming routing weights. This is
741+
optional with the default being ``None`` (pure stochastic sampling).
742+
temperature: The temperature to use.
635743
"""
636744

637745
corruption_process: CategoricalProcess
@@ -782,15 +890,17 @@ class DiscreteFlowMatchingStep(SamplerStep):
782890
This sampler uses the 3-way routing representation. The update rule
783891
decomposes naturally into:
784892
785-
p(x_s) = p_stay * δ_{x_t} + p_up * p_x0 + p_down * π
893+
p(x_s) = p_stay * δ_{x_t} + p_clean * p_x0 + p_noise * π
786894
787895
where:
788-
- p_stay = 1 - p_up - p_down
789-
- p_up = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
790-
- p_down = (α_s - α_t) / α_t * stoch_coeff
896+
- p_stay = 1 - p_clean - p_noise
897+
- p_clean = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
898+
- p_noise = (α_s - α_t) / α_t * stoch_coeff
791899
792900
Attributes:
793901
corruption_process: The corruption process to use.
902+
planner: The planner to use for transforming routing weights. This is
903+
optional with the default being ``None`` (pure stochastic sampling).
794904
temperature: The temperature to use.
795905
stoch_coeff: The stochasticity coefficient (default 0.0). Higher values
796906
introduce more noise during the denoising process.
@@ -853,25 +963,25 @@ def update(
853963
alpha_s = self.corruption_process.schedule.alpha(next_time_bcast)
854964
alpha_t = self.corruption_process.schedule.alpha(time_bcast)
855965

856-
prob_up = (
966+
p_clean_raw = (
857967
(alpha_s - alpha_t)
858968
/ jnp.maximum(1.0 - alpha_t, 1e-12)
859969
* (1.0 + self.stoch_coeff)
860970
)
861-
prob_down = (
971+
p_noise_raw = (
862972
(alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.stoch_coeff
863973
)
864974

865-
# Clip and rescale to ensure valid probabilities
866-
raw_p_up = jnp.maximum(prob_up, 0.0)
867-
raw_p_down = jnp.maximum(prob_down, 0.0)
868-
sum_jumps = raw_p_up + raw_p_down
975+
# Clip and rescale to ensure valid weights
976+
p_clean_raw = jnp.maximum(p_clean_raw, 0.0)
977+
p_noise_raw = jnp.maximum(p_noise_raw, 0.0)
978+
sum_jumps = p_clean_raw + p_noise_raw
869979
scale_factor = jnp.maximum(1.0, sum_jumps)
870980

871981
# Compute the probabilities for the three routing options.
872982
# This is computed according to https://arxiv.org/abs/2407.15595.
873-
p_clean = raw_p_up / scale_factor
874-
p_noise = raw_p_down / scale_factor
983+
p_clean = p_clean_raw / scale_factor
984+
p_noise = p_noise_raw / scale_factor
875985
p_stay = 1.0 - p_clean - p_noise
876986

877987
routing_weights = RoutingWeights(stay=p_stay, noise=p_noise, clean=p_clean)

hackable_diffusion/lib/sampling/discrete_step_sampler_test.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,5 +963,145 @@ def __call__(self, routing_weights, logits, x0, xt, time, next_time, key):
963963
chex.assert_trees_all_equal(out, routing_weights)
964964

965965

966+
class GreedyPlannerTest(absltest.TestCase):
967+
968+
def test_greedy_planner_budget(self):
969+
planner = discrete_step_sampler.GreedyPlanner()
970+
# 1 batch, 4 seq len. Realistic routing: all eligible with stay/noise > 0.
971+
routing_weights = discrete_step_sampler.RoutingWeights(
972+
stay=jnp.full((1, 4, 1), 0.3),
973+
noise=jnp.full((1, 4, 1), 0.2),
974+
clean=jnp.full((1, 4, 1), 0.5),
975+
)
976+
logits = jnp.array(
977+
[[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]]
978+
) # high confidence first
979+
x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
980+
xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
981+
time = jnp.array([1.0])
982+
next_time = jnp.array([0.5]) # frac = 0.5 -> budget = 4 eligible * 0.5 = 2
983+
key = jax.random.PRNGKey(0)
984+
985+
out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)
986+
987+
# Top 2 positions → force CLEAN (stay=0, noise=0, clean=1).
988+
# Non-selected → keep original stay/noise, zero out clean.
989+
expected = discrete_step_sampler.RoutingWeights(
990+
stay=jnp.array([[[0.0], [0.0], [0.3], [0.3]]]),
991+
noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
992+
clean=jnp.array([[[1.0], [1.0], [0.0], [0.0]]]),
993+
)
994+
chex.assert_trees_all_close(out_probs.stay, expected.stay)
995+
chex.assert_trees_all_close(out_probs.noise, expected.noise)
996+
chex.assert_trees_all_close(out_probs.clean, expected.clean)
997+
998+
def test_greedy_planner_eligibility(self):
999+
planner = discrete_step_sampler.GreedyPlanner()
1000+
# Position 0 is NOT eligible (p_clean = 0), has original stay=1.0.
1001+
routing_weights = discrete_step_sampler.RoutingWeights(
1002+
stay=jnp.array([[[1.0], [0.3], [0.3], [0.3]]]),
1003+
noise=jnp.array([[[0.0], [0.2], [0.2], [0.2]]]),
1004+
clean=jnp.array([[[0.0], [0.5], [0.5], [0.5]]]),
1005+
)
1006+
logits = jnp.array(
1007+
[[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]]
1008+
) # Pos 0 has highest logit but ineligible
1009+
x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
1010+
xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
1011+
time = jnp.array([1.0])
1012+
next_time = jnp.array([0.5]) # frac = 0.5 -> budget = 3 eligible * 0.5 = 1
1013+
key = jax.random.PRNGKey(0)
1014+
1015+
out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)
1016+
1017+
# Pos 0 is ineligible (p_clean=0), so num_eligible=3.
1018+
# Budget = 3 * 0.5 = 1 (truncated to int).
1019+
# Top 1 eligible position by confidence: Pos 1 → force CLEAN.
1020+
# Non-selected (Pos 0, 2, 3) → keep original stay/noise, zero clean.
1021+
expected = discrete_step_sampler.RoutingWeights(
1022+
stay=jnp.array([[[1.0], [0.0], [0.3], [0.3]]]),
1023+
noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
1024+
clean=jnp.array([[[0.0], [1.0], [0.0], [0.0]]]),
1025+
)
1026+
chex.assert_trees_all_close(out_probs.stay, expected.stay)
1027+
chex.assert_trees_all_close(out_probs.noise, expected.noise)
1028+
chex.assert_trees_all_close(out_probs.clean, expected.clean)
1029+
1030+
def test_greedy_planner_k_zero(self):
1031+
"""When clean weight is small, budget k=0. Keep original stay/noise."""
1032+
planner = discrete_step_sampler.GreedyPlanner()
1033+
routing_weights = discrete_step_sampler.RoutingWeights(
1034+
stay=jnp.full((1, 4, 1), 0.7),
1035+
noise=jnp.full((1, 4, 1), 0.2),
1036+
clean=jnp.full((1, 4, 1), 0.1),
1037+
)
1038+
logits = jnp.array([[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]])
1039+
x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
1040+
xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
1041+
time = jnp.array([1.0])
1042+
next_time = jnp.array([1.0])
1043+
key = jax.random.PRNGKey(0)
1044+
1045+
out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)
1046+
1047+
# k=0: no positions selected. All keep original stay/noise, clean zeroed.
1048+
expected = discrete_step_sampler.RoutingWeights(
1049+
stay=jnp.full((1, 4, 1), 0.7),
1050+
noise=jnp.full((1, 4, 1), 0.2),
1051+
clean=jnp.zeros((1, 4, 1)),
1052+
)
1053+
chex.assert_trees_all_close(out_probs.stay, expected.stay)
1054+
chex.assert_trees_all_close(out_probs.noise, expected.noise)
1055+
chex.assert_trees_all_close(out_probs.clean, expected.clean)
1056+
1057+
def test_greedy_planner_2d_spatial(self):
1058+
"""GreedyPlanner must work with 2D spatial data (e.g.
1059+
1060+
adjacency matrices).
1061+
"""
1062+
planner = discrete_step_sampler.GreedyPlanner()
1063+
# Shape (1, 3, 3, 1): batch=1, spatial=(3,3), vocab_trailing=1.
1064+
routing_weights = discrete_step_sampler.RoutingWeights(
1065+
stay=jnp.full((1, 3, 3, 1), 0.3),
1066+
noise=jnp.full((1, 3, 3, 1), 0.2),
1067+
clean=jnp.full((1, 3, 3, 1), 0.5),
1068+
)
1069+
# 9 positions total, 2 vocab classes.
1070+
# Logits: position (0,0) has highest confidence, then (0,1), etc.
1071+
logits_flat = jnp.array([
1072+
[10.0, 0.0],
1073+
[9.0, 0.0],
1074+
[8.0, 0.0],
1075+
[7.0, 0.0],
1076+
[6.0, 0.0],
1077+
[5.0, 0.0],
1078+
[4.0, 0.0],
1079+
[3.0, 0.0],
1080+
[2.0, 0.0],
1081+
])
1082+
logits = logits_flat.reshape(1, 3, 3, 2)
1083+
x0 = jnp.zeros((1, 3, 3, 1), dtype=jnp.int32)
1084+
xt = jnp.ones((1, 3, 3, 1), dtype=jnp.int32)
1085+
time = jnp.array([1.0])
1086+
next_time = jnp.array([0.5]) # frac = 0.5 -> budget = 9 * 0.5 = 4
1087+
key = jax.random.PRNGKey(0)
1088+
1089+
out = planner(routing_weights, logits, x0, xt, time, next_time, key)
1090+
1091+
# Output must have the same spatial shape.
1092+
self.assertEqual(out.stay.shape, (1, 3, 3, 1))
1093+
self.assertEqual(out.noise.shape, (1, 3, 3, 1))
1094+
self.assertEqual(out.clean.shape, (1, 3, 3, 1))
1095+
1096+
# Budget k = 4. Top-4 positions (by confidence) → forced CLEAN.
1097+
# Positions (0,0), (0,1), (0,2), (1,0) should be selected.
1098+
selected = out.clean[0, :, :, 0] # (3, 3)
1099+
num_selected = int(jnp.sum(selected > 0))
1100+
self.assertEqual(num_selected, 4)
1101+
# Selected positions have stay=0, noise=0.
1102+
self.assertEqual(float(jnp.sum(out.stay[out.clean > 0])), 0.0)
1103+
self.assertEqual(float(jnp.sum(out.noise[out.clean > 0])), 0.0)
1104+
1105+
9661106
if __name__ == '__main__':
9671107
absltest.main()

0 commit comments

Comments (0)