From 68eed8afd5ca977cc48cde94d6bf8200540038a8 Mon Sep 17 00:00:00 2001
From: Valentin De Bortoli
Date: Thu, 7 May 2026 09:32:48 -0700
Subject: [PATCH] Add GreedyPlanner matching.

PiperOrigin-RevId: 912005477
---
 hackable_diffusion/lib/sampling/__init__.py   |  1 +
 .../lib/sampling/discrete_step_sampler.py     | 81 ++++++++++++++++
 .../sampling/discrete_step_sampler_test.py    | 92 +++++++++++++++++++
 3 files changed, 174 insertions(+)

diff --git a/hackable_diffusion/lib/sampling/__init__.py b/hackable_diffusion/lib/sampling/__init__.py
index 29bcfcc..dc6fb90 100644
--- a/hackable_diffusion/lib/sampling/__init__.py
+++ b/hackable_diffusion/lib/sampling/__init__.py
@@ -25,6 +25,7 @@
 from hackable_diffusion.lib.sampling.discrete_step_sampler import CorruptedMaskFn
 from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteDDIMStep
 from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteFlowMatchingStep
+from hackable_diffusion.lib.sampling.discrete_step_sampler import GreedyPlanner
 from hackable_diffusion.lib.sampling.discrete_step_sampler import IntegratedDiscreteDDIMStep
 from hackable_diffusion.lib.sampling.discrete_step_sampler import MaskValueCorruptedMaskFn
 from hackable_diffusion.lib.sampling.discrete_step_sampler import MaxCappedRemaskingFn
diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py
index 74e675c..32550c7 100644
--- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py
+++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py
@@ -381,6 +381,87 @@ def __call__(
     ...
 
 
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class GreedyPlanner(RoutingStrategy):
+  """Confidence-based top-k greedy planner.
+
+  Selects the top-k positions by confidence of the sampled x0 token, forces
+  those to CLEAN, and lets the remaining positions follow their natural
+  stay/noise dynamics (with p_clean zeroed out).
+
+  The budget k is computed as:
+    k = floor(num_eligible * p_clean_norm)
+  where num_eligible is the number of positions with p_clean > 0 and
+  p_clean_norm is the normalized clean probability. For a linear schedule
+  α = 1−t this simplifies to k = num_eligible * (1 − next_time/time).
+
+  Attributes:
+    tie_breaking_noise: Scale of uniform noise added to confidence scores for
+      tie-breaking. Default 1e-6.
+  """
+
+  tie_breaking_noise: float = 1e-6
+
+  @kt.typechecked
+  def __call__(
+      self,
+      routing_weights: RoutingWeights,
+      logits: Float['... M'],
+      x0: DataArray,
+      xt: DataArray,
+      time: TimeArray,
+      next_time: TimeArray,
+      key: PRNGKey,
+  ) -> RoutingWeights:
+    # Confidence = softmax(logits)[x0] per position (x0 has a trailing 1-axis).
+    p = jax.nn.softmax(logits, axis=-1)
+    confidence = jnp.take_along_axis(p, x0, axis=-1).squeeze(-1)
+
+    # Mask positions that cannot go clean (p_clean <= 0) with -inf confidence.
+    eligible = routing_weights.clean[..., 0] > 0
+    confidence = jnp.where(eligible, confidence, -jnp.inf)
+
+    # Add small uniform noise to break ties between equal confidences.
+    _, subkey = jax.random.split(key)
+    confidence += jax.random.uniform(subkey, confidence.shape) * (
+        self.tie_breaking_noise
+    )
+
+    # Budget: k = num_eligible * p_clean_norm.
+    # p_clean_norm = p_clean / (p_stay + p_noise + p_clean) is the same
+    # for all eligible positions (π(x_t) cancels in normalization).
+    # For a linear schedule (α = 1-t), this equals (1 - next_time/time).
+    total_weight = (
+        routing_weights.stay + routing_weights.noise + routing_weights.clean
+    )
+    p_clean_norm = routing_weights.clean / jnp.maximum(total_weight, 1e-12)
+    # Max over all dims but the first recovers the shared p_clean_norm (0 where
+    # ineligible). NOTE(review): sum/sort below use only the last axis; confirm.
+    p_clean_norm_flat = p_clean_norm.reshape(p_clean_norm.shape[0], -1)
+    frac = jnp.max(p_clean_norm_flat, axis=-1, keepdims=True)
+    frac = jnp.clip(frac, 0.0, 1.0)
+
+    num_eligible = jnp.sum(eligible.astype(jnp.float32), axis=-1, keepdims=True)
+    k = (num_eligible * frac).astype(jnp.int32)
+
+    # Top-k threshold: the k-th largest confidence (ascending sort, reversed).
+    seq_len = confidence.shape[-1]
+    sorted_conf = jnp.sort(confidence, axis=-1)[..., ::-1]
+    threshold = jnp.take_along_axis(
+        sorted_conf, jnp.clip(k - 1, 0, seq_len - 1), axis=-1
+    )
+    # When k=0, no positions should be selected.
+    to_update = (confidence >= threshold) & (k > 0)
+
+    # Selected positions → force CLEAN (zero out stay/noise).
+    # Non-selected positions → zero out p_clean, keep original stay/noise.
+    return RoutingWeights(
+        stay=jnp.where(to_update[..., None], 0.0, routing_weights.stay),
+        noise=jnp.where(to_update[..., None], 0.0, routing_weights.noise),
+        clean=jnp.where(to_update[..., None], 1.0, 0.0),
+    )
+
+
 ################################################################################
 # MARK: UnMasking Step
 ################################################################################
diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py b/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py
index 9301226..d8990e1 100644
--- a/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py
+++ b/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py
@@ -963,5 +963,97 @@ def __call__(self, routing_weights, logits, x0, xt, time, next_time, key):
     chex.assert_trees_all_equal(out, routing_weights)
 
 
+class GreedyPlannerTest(absltest.TestCase):
+
+  def test_greedy_planner_budget(self):
+    planner = discrete_step_sampler.GreedyPlanner()
+    # 1 batch, 4 seq len. Realistic routing: all eligible with stay/noise > 0.
+    routing_weights = discrete_step_sampler.RoutingWeights(
+        stay=jnp.full((1, 4, 1), 0.3),
+        noise=jnp.full((1, 4, 1), 0.2),
+        clean=jnp.full((1, 4, 1), 0.5),
+    )
+    logits = jnp.array(
+        [[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]]
+    )  # high confidence first
+    x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
+    xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
+    time = jnp.array([1.0])
+    next_time = jnp.array([0.5])  # frac = 0.5 -> budget = 4 eligible * 0.5 = 2
+    key = jax.random.PRNGKey(0)
+
+    out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)
+
+    # Top 2 positions → force CLEAN (stay=0, noise=0, clean=1).
+    # Non-selected → keep original stay/noise, zero out clean.
+    expected = discrete_step_sampler.RoutingWeights(
+        stay=jnp.array([[[0.0], [0.0], [0.3], [0.3]]]),
+        noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
+        clean=jnp.array([[[1.0], [1.0], [0.0], [0.0]]]),
+    )
+    chex.assert_trees_all_close(out_probs.stay, expected.stay)
+    chex.assert_trees_all_close(out_probs.noise, expected.noise)
+    chex.assert_trees_all_close(out_probs.clean, expected.clean)
+
+  def test_greedy_planner_eligibility(self):
+    planner = discrete_step_sampler.GreedyPlanner()
+    # Position 0 is NOT eligible (p_clean = 0), has original stay=1.0.
+    routing_weights = discrete_step_sampler.RoutingWeights(
+        stay=jnp.array([[[1.0], [0.3], [0.3], [0.3]]]),
+        noise=jnp.array([[[0.0], [0.2], [0.2], [0.2]]]),
+        clean=jnp.array([[[0.0], [0.5], [0.5], [0.5]]]),
+    )
+    logits = jnp.array(
+        [[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]]
+    )  # Pos 0 has highest logit but ineligible
+    x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
+    xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
+    time = jnp.array([1.0])
+    next_time = jnp.array([0.5])  # frac = 0.5 -> budget = 3 eligible * 0.5 = 1
+    key = jax.random.PRNGKey(0)
+
+    out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)
+
+    # Pos 0 is ineligible (p_clean=0), so num_eligible=3.
+    # Budget = 3 * 0.5 = 1 (truncated to int).
+    # Top 1 eligible position by confidence: Pos 1 → force CLEAN.
+    # Non-selected (Pos 0, 2, 3) → keep original stay/noise, zero clean.
+    expected = discrete_step_sampler.RoutingWeights(
+        stay=jnp.array([[[1.0], [0.0], [0.3], [0.3]]]),
+        noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
+        clean=jnp.array([[[0.0], [1.0], [0.0], [0.0]]]),
+    )
+    chex.assert_trees_all_close(out_probs.stay, expected.stay)
+    chex.assert_trees_all_close(out_probs.noise, expected.noise)
+    chex.assert_trees_all_close(out_probs.clean, expected.clean)
+
+  def test_greedy_planner_k_zero(self):
+    """When clean weight is small, budget k=0. Keep original stay/noise."""
+    planner = discrete_step_sampler.GreedyPlanner()
+    routing_weights = discrete_step_sampler.RoutingWeights(
+        stay=jnp.full((1, 4, 1), 0.7),
+        noise=jnp.full((1, 4, 1), 0.2),
+        clean=jnp.full((1, 4, 1), 0.1),
+    )
+    logits = jnp.array([[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]])
+    x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
+    xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
+    time = jnp.array([1.0])
+    next_time = jnp.array([1.0])
+    key = jax.random.PRNGKey(0)
+
+    out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)
+
+    # frac ~ 0.1 -> k = int(0.4) = 0: all keep stay/noise, clean is zeroed.
+    expected = discrete_step_sampler.RoutingWeights(
+        stay=jnp.full((1, 4, 1), 0.7),
+        noise=jnp.full((1, 4, 1), 0.2),
+        clean=jnp.zeros((1, 4, 1)),
+    )
+    chex.assert_trees_all_close(out_probs.stay, expected.stay)
+    chex.assert_trees_all_close(out_probs.noise, expected.noise)
+    chex.assert_trees_all_close(out_probs.clean, expected.clean)
+
+
 if __name__ == '__main__':
   absltest.main()