Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 116 additions & 37 deletions hackable_diffusion/lib/sampling/discrete_step_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,75 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Actual implementation of the sampling steps.

This module proposes various implementations but they all have in common
the core logic:

* An `initialize` function that takes a starting state and returns the
first step of the diffusion process.
* An `update` function that takes the current state and returns the next step.
* A `finalize` function that takes the last state and returns the final
state.

At every step, the update function takes the current state and returns the next
state. The update is also in charge of computing other auxiliary information
such as volatility, drifts, etc.

The `InferenceFn` is also called within the step and converted into the
relevant representation, for instance score, velocity, etc.
"""Actual implementation of the sampling steps for discrete diffusion.

This module implements various discrete samplers (UnMasking, DDIM, Flow
Matching) using a unified **routing** architecture.

Core concepts:
* **Routing**: Samplers decompose the transition probability into a mixture of
three actions:
- STAY (0): Keep the current token `xt`.
- NOISE (1): Sample from the invariant distribution.
- CLEAN (2): Use the predicted clean token `x0`.
* **RoutingWeightPlanner**: An optional transformation applied to routing
weights before they are sampled. This allows injecting custom selection
strategies (e.g., greedy top-k) without modifying the sampler physics.

Standard interface for all samplers:
* `initialize`: Takes a starting state and returns the first step.
* `update`: Takes current state and returns the next step.
* `finalize`: Takes last state and returns the final state.

The following diagram illustrates the flow of the discrete denoising process:

┌───────────────────────┐
│ DiffusionStep         │
│ Holds current state   │
│ xt                    │
└───────────┬───────────┘
            │
            ▼
┌───────────────────────┐
│ InferenceFn           │
│ Calls neural network  │
│ to get predictions    │
└───────────┬───────────┘
            │ prediction (logits)
            ▼
┌───────────────────────┐
│ _generate_candidates  │
│ Extracts x0, x_noise  │
│                       │
└───────────┬───────────┘
            │
            ▼
┌───────────────────────┐
│ Compute Routing Wgts  │
│ Decomposes posterior  │
│ to STAY/NOISE/CLEAN   │
└───────────┬───────────┘
            │
            ▼
┌───────────────────────┐
│ Planner               │
│ Modifies weights      │
│ (Optional, e.g. topk) │
└───────────┬───────────┘
            │
            ▼
┌───────────────────────┐
│ _sample_routing       │
│ Categorical sampling  │
│ to get next xt        │
└───────────┬───────────┘
            │
            ▼
┌───────────────────────┐
│ DiffusionStep         │
│ Holds next state      │
│ xt                    │
└───────────────────────┘

This module also introduces the concepts of Routing and Planning for discrete
diffusion:
Expand Down Expand Up @@ -386,8 +438,8 @@ class GreedyPlanner(RoutingStrategy):
"""Confidence-based top-k greedy planner.

Selects the top-k positions by confidence of the sampled x0 token, forces
those to CLEAN, and handles the remaining positions according to the
``resample`` flag.

The budget k is computed as:
k = num_eligible * p_clean_norm
Expand All @@ -398,9 +450,14 @@ class GreedyPlanner(RoutingStrategy):
Attributes:
tie_breaking_noise: Scale of uniform noise added to confidence scores for
tie-breaking. Default 1e-6.
resample: If True, non-selected positions keep their full original routing
weights (stay, noise, and clean), allowing stochastic resampling via the
posterior. If False (default), non-selected positions have their clean
weight zeroed out so they can only stay or re-noise.
"""

tie_breaking_noise: float = 1e-6
resample: bool = False

@kt.typechecked
def __call__(
Expand All @@ -413,12 +470,21 @@ def __call__(
next_time: TimeArray,
key: PRNGKey,
) -> RoutingWeights:
# Routing weights have shape (batch, *spatial, 1). We flatten all spatial
# dims into a single "positions" dim so top-k selection works regardless
# of whether the data is 1D (sequences) or 2D (adjacency matrices).
batch_size = routing_weights.stay.shape[0]
spatial_shape = routing_weights.stay.shape[1:-1] # e.g. (seq,) or (N, N)

# Confidence = softmax(logits)[x0] per position
p = jax.nn.softmax(logits, axis=-1)
confidence = jnp.take_along_axis(p, x0, axis=-1).squeeze(-1)
# confidence: (batch, *spatial) → flatten to (batch, num_positions)
confidence = confidence.reshape(batch_size, -1)

# Only consider positions that could go clean (p_clean > 0)
eligible = routing_weights.clean[..., 0] > 0
eligible = eligible.reshape(batch_size, -1)
confidence = jnp.where(eligible, confidence, -jnp.inf)

# Add tie-breaking noise
Expand All @@ -437,24 +503,27 @@ def __call__(
p_clean_norm = routing_weights.clean / jnp.maximum(total_weight, 1e-12)
# p_clean_norm is the same for all eligible positions. Take max over spatial
# dimensions to get the value (ineligible positions have 0).
p_clean_norm_flat = p_clean_norm.reshape(p_clean_norm.shape[0], -1)
p_clean_norm_flat = p_clean_norm.reshape(batch_size, -1)
frac = jnp.max(p_clean_norm_flat, axis=-1, keepdims=True)
frac = jnp.clip(frac, 0.0, 1.0)

num_eligible = jnp.sum(eligible.astype(jnp.float32), axis=-1, keepdims=True)
k = (num_eligible * frac).astype(jnp.int32)

# Top-k threshold
seq_len = confidence.shape[-1]
# Top-k threshold (operates on flattened positions)
num_positions = confidence.shape[-1]
sorted_conf = jnp.sort(confidence, axis=-1)[..., ::-1]
threshold = jnp.take_along_axis(
sorted_conf, jnp.clip(k - 1, 0, seq_len - 1), axis=-1
sorted_conf, jnp.clip(k - 1, 0, num_positions - 1), axis=-1
)
# When k=0, no positions should be selected.
to_update = (confidence >= threshold) & (k > 0)

# Selected positions → force CLEAN (zero out stay/noise).
# Non-selected positions → zero out p_clean, keep original stay/noise.
# Unflatten to_update back to original spatial shape
to_update = to_update.reshape(batch_size, *spatial_shape)

# Non-selected positions have clean zeroed out; they can only stay or
# re-noise according to the posterior.
return RoutingWeights(
stay=jnp.where(to_update[..., None], 0.0, routing_weights.stay),
noise=jnp.where(to_update[..., None], 0.0, routing_weights.noise),
Expand Down Expand Up @@ -489,6 +558,8 @@ class UnMaskingStep(SamplerStep):

Attributes:
corruption_process: The corruption process to use.
planner: The planner to use for transforming routing weights. This is
optional with the default being ``None`` (pure stochastic sampling).
remasking_fn: The remasking function to use, see
https://arxiv.org/abs/2503.00307v1. This is optional with the default
being no remasking.
Expand Down Expand Up @@ -713,6 +784,12 @@ class DiscreteDDIMStep(SamplerStep):
Note: when π = δ_MASK (masking process) and x_t = MASK, this reduces to:
P(unmask) = (α_s - α_t) / (1 - α_t),
which coincides with the UnMaskingStep formula (without remasking).

Attributes:
corruption_process: The corruption process to use.
planner: The planner to use for transforming routing weights. This is
optional with the default being ``None`` (pure stochastic sampling).
temperature: The temperature to use.
"""

corruption_process: CategoricalProcess
Expand Down Expand Up @@ -863,15 +940,17 @@ class DiscreteFlowMatchingStep(SamplerStep):
This sampler uses the 3-way routing representation. The update rule
decomposes naturally into:

p(x_s) = p_stay * δ_{x_t} + p_clean * p_x0 + p_noise * π

where:
- p_stay = 1 - p_clean - p_noise
- p_clean = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
- p_noise = (α_s - α_t) / α_t * stoch_coeff

Attributes:
corruption_process: The corruption process to use.
planner: The planner to use for transforming routing weights. This is
optional with the default being ``None`` (pure stochastic sampling).
temperature: The temperature to use.
stoch_coeff: The stochasticity coefficient (default 0.0). Higher values
introduce more noise during the denoising process.
Expand Down Expand Up @@ -934,25 +1013,25 @@ def update(
alpha_s = self.corruption_process.schedule.alpha(next_time_bcast)
alpha_t = self.corruption_process.schedule.alpha(time_bcast)

prob_up = (
p_clean_raw = (
(alpha_s - alpha_t)
/ jnp.maximum(1.0 - alpha_t, 1e-12)
* (1.0 + self.stoch_coeff)
)
prob_down = (
p_noise_raw = (
(alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.stoch_coeff
)

# Clip and rescale to ensure valid probabilities
raw_p_up = jnp.maximum(prob_up, 0.0)
raw_p_down = jnp.maximum(prob_down, 0.0)
sum_jumps = raw_p_up + raw_p_down
# Clip and rescale to ensure valid weights
p_clean_raw = jnp.maximum(p_clean_raw, 0.0)
p_noise_raw = jnp.maximum(p_noise_raw, 0.0)
sum_jumps = p_clean_raw + p_noise_raw
scale_factor = jnp.maximum(1.0, sum_jumps)

# Compute the probabilities for the three routing options.
# This is computed according to https://arxiv.org/abs/2407.15595.
p_clean = raw_p_up / scale_factor
p_noise = raw_p_down / scale_factor
p_clean = p_clean_raw / scale_factor
p_noise = p_noise_raw / scale_factor
p_stay = 1.0 - p_clean - p_noise

routing_weights = RoutingWeights(stay=p_stay, noise=p_noise, clean=p_clean)
Expand Down
48 changes: 48 additions & 0 deletions hackable_diffusion/lib/sampling/discrete_step_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,54 @@ def test_greedy_planner_k_zero(self):
chex.assert_trees_all_close(out_probs.noise, expected.noise)
chex.assert_trees_all_close(out_probs.clean, expected.clean)

def test_greedy_planner_2d_spatial(self):
  """GreedyPlanner handles 2D spatial data such as adjacency matrices.

  Routing weights of shape (batch, H, W, 1) must be flattened internally
  for top-k selection and restored to the original spatial shape on output.
  """
  planner = discrete_step_sampler.GreedyPlanner()

  # batch=1, spatial grid 3x3, trailing weight dim of 1.
  def uniform(value):
    return jnp.full((1, 3, 3, 1), value)

  routing_weights = discrete_step_sampler.RoutingWeights(
      stay=uniform(0.3), noise=uniform(0.2), clean=uniform(0.5)
  )

  # Two vocab classes over 9 positions; class-0 confidence decreases
  # monotonically from position (0, 0) down to position (2, 2).
  per_position_logits = [[10.0 - i, 0.0] for i in range(9)]
  logits = jnp.asarray(per_position_logits).reshape(1, 3, 3, 2)

  x0 = jnp.zeros((1, 3, 3, 1), dtype=jnp.int32)
  xt = jnp.ones((1, 3, 3, 1), dtype=jnp.int32)
  time = jnp.array([1.0])
  next_time = jnp.array([0.5])  # frac = 0.5 -> budget k = 9 * 0.5 = 4
  key = jax.random.PRNGKey(0)

  out = planner(routing_weights, logits, x0, xt, time, next_time, key)

  # The planner must preserve the (batch, 3, 3, 1) spatial layout.
  for weights in (out.stay, out.noise, out.clean):
    self.assertEqual(weights.shape, (1, 3, 3, 1))

  # Budget k = 4: exactly the top-4 positions by confidence are forced
  # CLEAN — positions (0,0), (0,1), (0,2), (1,0).
  selected_mask = out.clean[0, :, :, 0] > 0
  self.assertEqual(int(jnp.sum(selected_mask)), 4)
  # Forced-CLEAN positions must carry zero stay/noise weight.
  self.assertEqual(float(jnp.sum(out.stay[out.clean > 0])), 0.0)
  self.assertEqual(float(jnp.sum(out.noise[out.clean > 0])), 0.0)


if __name__ == '__main__':
absltest.main()
Loading