Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hackable_diffusion/lib/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from hackable_diffusion.lib.sampling.discrete_step_sampler import CorruptedMaskFn
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteDDIMStep
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteFlowMatchingStep
from hackable_diffusion.lib.sampling.discrete_step_sampler import GreedyPlanner
from hackable_diffusion.lib.sampling.discrete_step_sampler import IntegratedDiscreteDDIMStep
from hackable_diffusion.lib.sampling.discrete_step_sampler import MaskValueCorruptedMaskFn
from hackable_diffusion.lib.sampling.discrete_step_sampler import MaxCappedRemaskingFn
Expand Down
81 changes: 81 additions & 0 deletions hackable_diffusion/lib/sampling/discrete_step_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,87 @@ def __call__(
...


@dataclasses.dataclass(frozen=True, kw_only=True)
class GreedyPlanner(RoutingStrategy):
"""Confidence-based top-k greedy planner.

Selects the top-k positions by confidence of the sampled x0 token, forces
those to CLEAN, and lets the remaining positions follow their natural
stay/noise dynamics (with p_clean zeroed out).

The budget k is computed as:
k = num_eligible * p_clean_norm
where num_eligible is the number of positions with p_clean > 0 and
p_clean_norm is the normalized clean probability. For a linear schedule
α = 1−t this simplifies to k = num_eligible * (1 − next_time/time).

Attributes:
tie_breaking_noise: Scale of uniform noise added to confidence scores for
tie-breaking. Default 1e-6.
"""

tie_breaking_noise: float = 1e-6

@kt.typechecked
def __call__(
self,
routing_weights: RoutingWeights,
logits: Float['... M'],
x0: DataArray,
xt: DataArray,
time: TimeArray,
next_time: TimeArray,
key: PRNGKey,
) -> RoutingWeights:
# Confidence = softmax(logits)[x0] per position
p = jax.nn.softmax(logits, axis=-1)
confidence = jnp.take_along_axis(p, x0, axis=-1).squeeze(-1)

# Only consider positions that could go clean (p_clean > 0)
eligible = routing_weights.clean[..., 0] > 0
confidence = jnp.where(eligible, confidence, -jnp.inf)

# Add tie-breaking noise
_, subkey = jax.random.split(key)
confidence += jax.random.uniform(subkey, confidence.shape) * (
self.tie_breaking_noise
)

# Budget: k = num_eligible * p_clean_norm.
# p_clean_norm = p_clean / (p_stay + p_noise + p_clean) is the same
# for all eligible positions (π(x_t) cancels in normalization).
# For a linear schedule (α = 1-t), this equals (1 - next_time/time).
total_weight = (
routing_weights.stay + routing_weights.noise + routing_weights.clean
)
p_clean_norm = routing_weights.clean / jnp.maximum(total_weight, 1e-12)
# p_clean_norm is the same for all eligible positions. Take max over spatial
# dimensions to get the value (ineligible positions have 0).
p_clean_norm_flat = p_clean_norm.reshape(p_clean_norm.shape[0], -1)
frac = jnp.max(p_clean_norm_flat, axis=-1, keepdims=True)
frac = jnp.clip(frac, 0.0, 1.0)

num_eligible = jnp.sum(eligible.astype(jnp.float32), axis=-1, keepdims=True)
k = (num_eligible * frac).astype(jnp.int32)

# Top-k threshold
seq_len = confidence.shape[-1]
sorted_conf = jnp.sort(confidence, axis=-1)[..., ::-1]
threshold = jnp.take_along_axis(
sorted_conf, jnp.clip(k - 1, 0, seq_len - 1), axis=-1
)
# When k=0, no positions should be selected.
to_update = (confidence >= threshold) & (k > 0)

# Selected positions → force CLEAN (zero out stay/noise).
# Non-selected positions → zero out p_clean, keep original stay/noise.
return RoutingWeights(
stay=jnp.where(to_update[..., None], 0.0, routing_weights.stay),
noise=jnp.where(to_update[..., None], 0.0, routing_weights.noise),
clean=jnp.where(to_update[..., None], 1.0, 0.0),
)


################################################################################
# MARK: UnMasking Step
################################################################################
Expand Down
92 changes: 92 additions & 0 deletions hackable_diffusion/lib/sampling/discrete_step_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,5 +963,97 @@ def __call__(self, routing_weights, logits, x0, xt, time, next_time, key):
chex.assert_trees_all_equal(out, routing_weights)


class GreedyPlannerTest(absltest.TestCase):

def test_greedy_planner_budget(self):
planner = discrete_step_sampler.GreedyPlanner()
# 1 batch, 4 seq len. Realistic routing: all eligible with stay/noise > 0.
routing_weights = discrete_step_sampler.RoutingWeights(
stay=jnp.full((1, 4, 1), 0.3),
noise=jnp.full((1, 4, 1), 0.2),
clean=jnp.full((1, 4, 1), 0.5),
)
logits = jnp.array(
[[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]]
) # high confidence first
x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
time = jnp.array([1.0])
next_time = jnp.array([0.5]) # frac = 0.5 -> budget = 4 eligible * 0.5 = 2
key = jax.random.PRNGKey(0)

out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)

# Top 2 positions → force CLEAN (stay=0, noise=0, clean=1).
# Non-selected → keep original stay/noise, zero out clean.
expected = discrete_step_sampler.RoutingWeights(
stay=jnp.array([[[0.0], [0.0], [0.3], [0.3]]]),
noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
clean=jnp.array([[[1.0], [1.0], [0.0], [0.0]]]),
)
chex.assert_trees_all_close(out_probs.stay, expected.stay)
chex.assert_trees_all_close(out_probs.noise, expected.noise)
chex.assert_trees_all_close(out_probs.clean, expected.clean)

def test_greedy_planner_eligibility(self):
planner = discrete_step_sampler.GreedyPlanner()
# Position 0 is NOT eligible (p_clean = 0), has original stay=1.0.
routing_weights = discrete_step_sampler.RoutingWeights(
stay=jnp.array([[[1.0], [0.3], [0.3], [0.3]]]),
noise=jnp.array([[[0.0], [0.2], [0.2], [0.2]]]),
clean=jnp.array([[[0.0], [0.5], [0.5], [0.5]]]),
)
logits = jnp.array(
[[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]]
) # Pos 0 has highest logit but ineligible
x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
time = jnp.array([1.0])
next_time = jnp.array([0.5]) # frac = 0.5 -> budget = 3 eligible * 0.5 = 1
key = jax.random.PRNGKey(0)

out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)

# Pos 0 is ineligible (p_clean=0), so num_eligible=3.
# Budget = 3 * 0.5 = 1 (truncated to int).
# Top 1 eligible position by confidence: Pos 1 → force CLEAN.
# Non-selected (Pos 0, 2, 3) → keep original stay/noise, zero clean.
expected = discrete_step_sampler.RoutingWeights(
stay=jnp.array([[[1.0], [0.0], [0.3], [0.3]]]),
noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
clean=jnp.array([[[0.0], [1.0], [0.0], [0.0]]]),
)
chex.assert_trees_all_close(out_probs.stay, expected.stay)
chex.assert_trees_all_close(out_probs.noise, expected.noise)
chex.assert_trees_all_close(out_probs.clean, expected.clean)

def test_greedy_planner_k_zero(self):
"""When clean weight is small, budget k=0. Keep original stay/noise."""
planner = discrete_step_sampler.GreedyPlanner()
routing_weights = discrete_step_sampler.RoutingWeights(
stay=jnp.full((1, 4, 1), 0.7),
noise=jnp.full((1, 4, 1), 0.2),
clean=jnp.full((1, 4, 1), 0.1),
)
logits = jnp.array([[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]])
x0 = jnp.zeros((1, 4, 1), dtype=jnp.int32)
xt = jnp.ones((1, 4, 1), dtype=jnp.int32)
time = jnp.array([1.0])
next_time = jnp.array([1.0])
key = jax.random.PRNGKey(0)

out_probs = planner(routing_weights, logits, x0, xt, time, next_time, key)

# k=0: no positions selected. All keep original stay/noise, clean zeroed.
expected = discrete_step_sampler.RoutingWeights(
stay=jnp.full((1, 4, 1), 0.7),
noise=jnp.full((1, 4, 1), 0.2),
clean=jnp.zeros((1, 4, 1)),
)
chex.assert_trees_all_close(out_probs.stay, expected.stay)
chex.assert_trees_all_close(out_probs.noise, expected.noise)
chex.assert_trees_all_close(out_probs.clean, expected.clean)


if __name__ == '__main__':
absltest.main()
Loading