Skip to content

Commit 2fb7aa9

Browse files
vdebortoHackable Diffusion Authors
authored andcommitted
Add GreedyPlanner matching.
PiperOrigin-RevId: 908000168
1 parent ae44b17 commit 2fb7aa9

3 files changed

Lines changed: 174 additions & 0 deletions

File tree

hackable_diffusion/lib/sampling/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from hackable_diffusion.lib.sampling.discrete_step_sampler import CorruptedMaskFn
2626
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteDDIMStep
2727
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteFlowMatchingStep
28+
from hackable_diffusion.lib.sampling.discrete_step_sampler import GreedyPlanner
2829
from hackable_diffusion.lib.sampling.discrete_step_sampler import IntegratedDiscreteDDIMStep
2930
from hackable_diffusion.lib.sampling.discrete_step_sampler import MaskValueCorruptedMaskFn
3031
from hackable_diffusion.lib.sampling.discrete_step_sampler import MaxCappedRemaskingFn

hackable_diffusion/lib/sampling/discrete_step_sampler.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,87 @@ def __call__(
381381
...
382382

383383

384+
@dataclasses.dataclass(frozen=True, kw_only=True)
385+
class GreedyPlanner(RoutingStrategy):
386+
"""Confidence-based top-k greedy planner.
387+
388+
Selects the top-k positions by confidence of the sampled x0 token, forces
389+
those to CLEAN, and lets the remaining positions follow their natural
390+
stay/noise dynamics (with p_clean zeroed out).
391+
392+
The budget k is computed as:
393+
k = num_eligible * p_clean_norm
394+
where num_eligible is the number of positions with p_clean > 0 and
395+
p_clean_norm is the normalized clean probability. For a linear schedule
396+
α = 1−t this simplifies to k = num_eligible * (1 − next_time/time).
397+
398+
Attributes:
399+
tie_breaking_noise: Scale of uniform noise added to confidence scores for
400+
tie-breaking. Default 1e-6.
401+
"""
402+
403+
tie_breaking_noise: float = 1e-6
404+
405+
@kt.typechecked
406+
def __call__(
407+
self,
408+
routing_weights: RoutingWeights,
409+
logits: Float['... M'],
410+
x0: DataArray,
411+
xt: DataArray,
412+
time: TimeArray,
413+
next_time: TimeArray,
414+
key: PRNGKey,
415+
) -> RoutingWeights:
416+
# Confidence = softmax(logits)[x0] per position
417+
p = jax.nn.softmax(logits, axis=-1)
418+
confidence = jnp.take_along_axis(p, x0, axis=-1).squeeze(-1)
419+
420+
# Only consider positions that could go clean (p_clean > 0)
421+
eligible = routing_weights.clean[..., 0] > 0
422+
confidence = jnp.where(eligible, confidence, -jnp.inf)
423+
424+
# Add tie-breaking noise
425+
_, subkey = jax.random.split(key)
426+
confidence += jax.random.uniform(subkey, confidence.shape) * (
427+
self.tie_breaking_noise
428+
)
429+
430+
# Budget: k = num_eligible * p_clean_norm.
431+
# p_clean_norm = p_clean / (p_stay + p_noise + p_clean) is the same
432+
# for all eligible positions (π(x_t) cancels in normalization).
433+
# For a linear schedule (α = 1-t), this equals (1 - next_time/time).
434+
total_weight = (
435+
routing_weights.stay + routing_weights.noise + routing_weights.clean
436+
)
437+
p_clean_norm = routing_weights.clean / jnp.maximum(total_weight, 1e-12)
438+
# p_clean_norm is the same for all eligible positions. Take max over spatial
439+
# dimensions to get the value (ineligible positions have 0).
440+
p_clean_norm_flat = p_clean_norm.reshape(p_clean_norm.shape[0], -1)
441+
frac = jnp.max(p_clean_norm_flat, axis=-1, keepdims=True)
442+
frac = jnp.clip(frac, 0.0, 1.0)
443+
444+
num_eligible = jnp.sum(eligible.astype(jnp.float32), axis=-1, keepdims=True)
445+
k = (num_eligible * frac).astype(jnp.int32)
446+
447+
# Top-k threshold
448+
seq_len = confidence.shape[-1]
449+
sorted_conf = jnp.sort(confidence, axis=-1)[..., ::-1]
450+
threshold = jnp.take_along_axis(
451+
sorted_conf, jnp.clip(k - 1, 0, seq_len - 1), axis=-1
452+
)
453+
# When k=0, no positions should be selected.
454+
to_update = (confidence >= threshold) & (k > 0)
455+
456+
# Selected positions → force CLEAN (zero out stay/noise).
457+
# Non-selected positions → zero out p_clean, keep original stay/noise.
458+
return RoutingWeights(
459+
stay=jnp.where(to_update[..., None], 0.0, routing_weights.stay),
460+
noise=jnp.where(to_update[..., None], 0.0, routing_weights.noise),
461+
clean=jnp.where(to_update[..., None], 1.0, 0.0),
462+
)
463+
464+
384465
################################################################################
385466
# MARK: UnMasking Step
386467
################################################################################

hackable_diffusion/lib/sampling/discrete_step_sampler_test.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,5 +963,97 @@ def __call__(self, routing_weights, logits, x0, xt, time, next_time, key):
963963
chex.assert_trees_all_equal(out, routing_weights)
964964

965965

966+
class GreedyPlannerTest(absltest.TestCase):
967+
968+
def test_greedy_planner_budget(self):
969+
planner = discrete_step_sampler.GreedyPlanner()
970+
# 1 batch, 4 seq len. Realistic routing: all eligible with stay/noise > 0.
971+
routing_weights = discrete_step_sampler.RoutingWeights(
972+
stay=jnp.full((1, 4, 1), 0.3),
973+
noise=jnp.full((1, 4, 1), 0.2),
974+
clean=jnp.full((1, 4, 1), 0.5),
975+
)
976+
logits = jnp.array(
977+
[[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]]
978+
) # high confidence first
979+
x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
980+
xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
981+
time = jnp.array([1.0])
982+
next_time = jnp.array([0.5]) # frac = 0.5 -> budget = 4 eligible * 0.5 = 2
983+
key = jax.random.PRNGKey(0)
984+
985+
out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)
986+
987+
# Top 2 positions → force CLEAN (stay=0, noise=0, clean=1).
988+
# Non-selected → keep original stay/noise, zero out clean.
989+
expected = discrete_step_sampler.RoutingWeights(
990+
stay=jnp.array([[[0.0], [0.0], [0.3], [0.3]]]),
991+
noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
992+
clean=jnp.array([[[1.0], [1.0], [0.0], [0.0]]]),
993+
)
994+
chex.assert_trees_all_close(out_probs.stay, expected.stay)
995+
chex.assert_trees_all_close(out_probs.noise, expected.noise)
996+
chex.assert_trees_all_close(out_probs.clean, expected.clean)
997+
998+
def test_greedy_planner_eligibility(self):
999+
planner = discrete_step_sampler.GreedyPlanner()
1000+
# Position 0 is NOT eligible (p_clean = 0), has original stay=1.0.
1001+
routing_weights = discrete_step_sampler.RoutingWeights(
1002+
stay=jnp.array([[[1.0], [0.3], [0.3], [0.3]]]),
1003+
noise=jnp.array([[[0.0], [0.2], [0.2], [0.2]]]),
1004+
clean=jnp.array([[[0.0], [0.5], [0.5], [0.5]]]),
1005+
)
1006+
logits = jnp.array(
1007+
[[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]]
1008+
) # Pos 0 has highest logit but ineligible
1009+
x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
1010+
xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
1011+
time = jnp.array([1.0])
1012+
next_time = jnp.array([0.5]) # frac = 0.5 -> budget = 3 eligible * 0.5 = 1
1013+
key = jax.random.PRNGKey(0)
1014+
1015+
out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)
1016+
1017+
# Pos 0 is ineligible (p_clean=0), so num_eligible=3.
1018+
# Budget = 3 * 0.5 = 1 (truncated to int).
1019+
# Top 1 eligible position by confidence: Pos 1 → force CLEAN.
1020+
# Non-selected (Pos 0, 2, 3) → keep original stay/noise, zero clean.
1021+
expected = discrete_step_sampler.RoutingWeights(
1022+
stay=jnp.array([[[1.0], [0.0], [0.3], [0.3]]]),
1023+
noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
1024+
clean=jnp.array([[[0.0], [1.0], [0.0], [0.0]]]),
1025+
)
1026+
chex.assert_trees_all_close(out_probs.stay, expected.stay)
1027+
chex.assert_trees_all_close(out_probs.noise, expected.noise)
1028+
chex.assert_trees_all_close(out_probs.clean, expected.clean)
1029+
1030+
def test_greedy_planner_k_zero(self):
1031+
"""When clean weight is small, budget k=0. Keep original stay/noise."""
1032+
planner = discrete_step_sampler.GreedyPlanner()
1033+
routing_weights = discrete_step_sampler.RoutingWeights(
1034+
stay=jnp.full((1, 4, 1), 0.7),
1035+
noise=jnp.full((1, 4, 1), 0.2),
1036+
clean=jnp.full((1, 4, 1), 0.1),
1037+
)
1038+
logits = jnp.array([[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]])
1039+
x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
1040+
xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
1041+
time = jnp.array([1.0])
1042+
next_time = jnp.array([1.0])
1043+
key = jax.random.PRNGKey(0)
1044+
1045+
out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)
1046+
1047+
# k=0: no positions selected. All keep original stay/noise, clean zeroed.
1048+
expected = discrete_step_sampler.RoutingWeights(
1049+
stay=jnp.full((1, 4, 1), 0.7),
1050+
noise=jnp.full((1, 4, 1), 0.2),
1051+
clean=jnp.zeros((1, 4, 1)),
1052+
)
1053+
chex.assert_trees_all_close(out_probs.stay, expected.stay)
1054+
chex.assert_trees_all_close(out_probs.noise, expected.noise)
1055+
chex.assert_trees_all_close(out_probs.clean, expected.clean)
1056+
1057+
9661058
if __name__ == '__main__':
9671059
absltest.main()

0 commit comments

Comments
 (0)