1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- """Actual implementation of the sampling steps.
16-
17- This module proposes various implementations but they all have in common
18- the core logic:
19-
20- * An `initialize` function that takes a starting state and returns the
21- first step of the diffusion process.
22- * An `update` function that takes the current state and returns the next step.
23- * A `finalize` function that takes the last state and returns the final
24- state.
25-
26- At every step, the update function takes the current state and returns the next
27- state. The update is also in charge of computing other auxiliary informations
28- such as volatility, drifts, etc.
29-
30- The `InferenceFn` is also called within the step and converted into the
31- relevant representation, for instance score, velocity, etc.
15+ """Actual implementation of the sampling steps for discrete diffusion.
16+
17+ This module implements various discrete samplers (UnMasking, DDIM, Flow
18+ Matching) using a unified **routing** architecture.
19+
20+ Core concepts:
21+ * **Routing**: Samplers decompose the transition probability into a mixture of
22+ three actions:
23+ - STAY (0): Keep the current token `xt`.
24+ - NOISE (1): Sample from the invariant distribution.
25+ - CLEAN (2): Use the predicted clean token `x0`.
26+ * **RoutingWeightPlanner**: An optional transformation applied to routing
27+ weights before they are sampled. This allows injecting custom selection
28+ strategies (e.g., greedy top-k) without modifying the sampler physics.
29+
30+ Standard interface for all samplers:
31+ * `initialize`: Takes a starting state and returns the first step.
32+ * `update`: Takes current state and returns the next step.
33+ * `finalize`: Takes last state and returns the final state.
34+
35+ The following diagram illustrates the flow of the discrete denoising process:
36+
37+ ┌───────────────────────┐
38+ │ DiffusionStep │
39+ │ Holds current state │
40+ │ xt │
41+ └───────────┬───────────┘
42+ │
43+ ▼
44+ ┌───────────────────────┐
45+ │ InferenceFn │
46+ │ Calls neural network │
47+ │ to get predictions │
48+ └───────────┬───────────┘
49+ │ prediction (logits)
50+ ▼
51+ ┌───────────────────────┐
52+ │ _generate_candidates │
53+ │ Extracts x0, x_noise │
54+ │ │
55+ └───────────┬───────────┘
56+ │
57+ ▼
58+ ┌───────────────────────┐
59+ │ Compute Routing Wgts │
60+ │ Decomposes posterior │
61+ │ to STAY/NOISE/CLEAN │
62+ └───────────┬───────────┘
63+ │
64+ ▼
65+ ┌───────────────────────┐
66+ │ Planner │
67+ │ Modifies weights │
68+ │ (Optional, e.g. topk)│
69+ └───────────┬───────────┘
70+ │
71+ ▼
72+ ┌───────────────────────┐
73+ │ _sample_routing │
74+ │ Categorical sampling │
75+ │ to get next xt │
76+ └───────────┬───────────┘
77+ │
78+ ▼
79+ ┌───────────────────────┐
80+ │ DiffusionStep │
81+ │ Holds next state │
82+ │ xt │
83+ └───────────────────────┘
3284
3385This module also introduces the concepts of Routing and Planning for discrete
3486diffusion:
@@ -386,8 +438,8 @@ class GreedyPlanner(RoutingStrategy):
386438 """Confidence-based top-k greedy planner.
387439
388440 Selects the top-k positions by confidence of the sampled x0 token, forces
389- those to CLEAN, and lets the remaining positions follow their natural
390- stay/noise dynamics (with p_clean zeroed out) .
441+ those to CLEAN, and handles the remaining positions according to the
442+ ``resample`` flag.
391443
392444 The budget k is computed as:
393445 k = num_eligible * p_clean_norm
@@ -398,9 +450,14 @@ class GreedyPlanner(RoutingStrategy):
398450 Attributes:
399451 tie_breaking_noise: Scale of uniform noise added to confidence scores for
400452 tie-breaking. Default 1e-6.
453+ resample: If True, non-selected positions keep their full original routing
454+ weights (stay, noise, and clean), allowing stochastic resampling via the
455+ posterior. If False (default), non-selected positions have their clean
456+ weight zeroed out so they can only stay or re-noise.
401457 """
402458
403459 tie_breaking_noise : float = 1e-6
460+ resample : bool = False
404461
405462 @kt .typechecked
406463 def __call__ (
@@ -413,12 +470,21 @@ def __call__(
413470 next_time : TimeArray ,
414471 key : PRNGKey ,
415472 ) -> RoutingWeights :
473+ # Routing weights have shape (batch, *spatial, 1). We flatten all spatial
474+ # dims into a single "positions" dim so top-k selection works regardless
475+ # of whether the data is 1D (sequences) or 2D (adjacency matrices).
476+ batch_size = routing_weights .stay .shape [0 ]
477+ spatial_shape = routing_weights .stay .shape [1 :- 1 ] # e.g. (seq,) or (N, N)
478+
416479 # Confidence = softmax(logits)[x0] per position
417480 p = jax .nn .softmax (logits , axis = - 1 )
418481 confidence = jnp .take_along_axis (p , x0 , axis = - 1 ).squeeze (- 1 )
482+ # confidence: (batch, *spatial) → flatten to (batch, num_positions)
483+ confidence = confidence .reshape (batch_size , - 1 )
419484
420485 # Only consider positions that could go clean (p_clean > 0)
421486 eligible = routing_weights .clean [..., 0 ] > 0
487+ eligible = eligible .reshape (batch_size , - 1 )
422488 confidence = jnp .where (eligible , confidence , - jnp .inf )
423489
424490 # Add tie-breaking noise
@@ -437,24 +503,27 @@ def __call__(
437503 p_clean_norm = routing_weights .clean / jnp .maximum (total_weight , 1e-12 )
438504 # p_clean_norm is the same for all eligible positions. Take max over spatial
439505 # dimensions to get the value (ineligible positions have 0).
440- p_clean_norm_flat = p_clean_norm .reshape (p_clean_norm . shape [ 0 ] , - 1 )
506+ p_clean_norm_flat = p_clean_norm .reshape (batch_size , - 1 )
441507 frac = jnp .max (p_clean_norm_flat , axis = - 1 , keepdims = True )
442508 frac = jnp .clip (frac , 0.0 , 1.0 )
443509
444510 num_eligible = jnp .sum (eligible .astype (jnp .float32 ), axis = - 1 , keepdims = True )
445511 k = (num_eligible * frac ).astype (jnp .int32 )
446512
447- # Top-k threshold
448- seq_len = confidence .shape [- 1 ]
513+ # Top-k threshold (operates on flattened positions)
514+ num_positions = confidence .shape [- 1 ]
449515 sorted_conf = jnp .sort (confidence , axis = - 1 )[..., ::- 1 ]
450516 threshold = jnp .take_along_axis (
451- sorted_conf , jnp .clip (k - 1 , 0 , seq_len - 1 ), axis = - 1
517+ sorted_conf , jnp .clip (k - 1 , 0 , num_positions - 1 ), axis = - 1
452518 )
453519 # When k=0, no positions should be selected.
454520 to_update = (confidence >= threshold ) & (k > 0 )
455521
456- # Selected positions → force CLEAN (zero out stay/noise).
457- # Non-selected positions → zero out p_clean, keep original stay/noise.
522+ # Unflatten to_update back to original spatial shape
523+ to_update = to_update .reshape (batch_size , * spatial_shape )
524+
525+ # Non-selected positions have clean zeroed out; they can only stay or
526+ # re-noise according to the posterior.
458527 return RoutingWeights (
459528 stay = jnp .where (to_update [..., None ], 0.0 , routing_weights .stay ),
460529 noise = jnp .where (to_update [..., None ], 0.0 , routing_weights .noise ),
@@ -489,6 +558,8 @@ class UnMaskingStep(SamplerStep):
489558
490559 Attributes:
491560 corruption_process: The corruption process to use.
561+ planner: The planner to use for transforming routing weights. This is
562+ optional with the default being ``None`` (pure stochastic sampling).
492563 remasking_fn: The remasking function to use, see
493564 https://arxiv.org/abs/2503.00307v1. This is optional with the default
494565 being no remasking.
@@ -713,6 +784,12 @@ class DiscreteDDIMStep(SamplerStep):
713784 Note: when π = δ_MASK (masking process) and x_t = MASK, this reduces to:
714785 P(unmask) = (α_s - α_t) / (1 - α_t),
715786 which coincides with the UnMaskingStep formula (without remasking).
787+
788+ Attributes:
789+ corruption_process: The corruption process to use.
790+ planner: The planner to use for transforming routing weights. This is
791+ optional with the default being ``None`` (pure stochastic sampling).
792+ temperature: The temperature to use.
716793 """
717794
718795 corruption_process : CategoricalProcess
@@ -863,15 +940,17 @@ class DiscreteFlowMatchingStep(SamplerStep):
863940 This sampler uses the 3-way routing representation. The update rule
864941 decomposes naturally into:
865942
866- p(x_s) = p_stay * δ_{x_t} + p_up * p_x0 + p_down * π
943+ p(x_s) = p_stay * δ_{x_t} + p_clean * p_x0 + p_noise * π
867944
868945 where:
869- - p_stay = 1 - p_up - p_down
870- - p_up = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
871- - p_down = (α_s - α_t) / α_t * stoch_coeff
946+ - p_stay = 1 - p_clean - p_noise
947+ - p_clean = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
948+ - p_noise = (α_s - α_t) / α_t * stoch_coeff
872949
873950 Attributes:
874951 corruption_process: The corruption process to use.
952+ planner: The planner to use for transforming routing weights. This is
953+ optional with the default being ``None`` (pure stochastic sampling).
875954 temperature: The temperature to use.
876955 stoch_coeff: The stochasticity coefficient (default 0.0). Higher values
877956 introduce more noise during the denoising process.
@@ -934,25 +1013,25 @@ def update(
9341013 alpha_s = self .corruption_process .schedule .alpha (next_time_bcast )
9351014 alpha_t = self .corruption_process .schedule .alpha (time_bcast )
9361015
937- prob_up = (
1016+ p_clean_raw = (
9381017 (alpha_s - alpha_t )
9391018 / jnp .maximum (1.0 - alpha_t , 1e-12 )
9401019 * (1.0 + self .stoch_coeff )
9411020 )
942- prob_down = (
1021+ p_noise_raw = (
9431022 (alpha_s - alpha_t ) / jnp .maximum (alpha_t , 1e-12 ) * self .stoch_coeff
9441023 )
9451024
946- # Clip and rescale to ensure valid probabilities
947- raw_p_up = jnp .maximum (prob_up , 0.0 )
948- raw_p_down = jnp .maximum (prob_down , 0.0 )
949- sum_jumps = raw_p_up + raw_p_down
1025+ # Clip and rescale to ensure valid weights
1026+ p_clean_raw = jnp .maximum (p_clean_raw , 0.0 )
1027+ p_noise_raw = jnp .maximum (p_noise_raw , 0.0 )
1028+ sum_jumps = p_clean_raw + p_noise_raw
9501029 scale_factor = jnp .maximum (1.0 , sum_jumps )
9511030
9521031 # Compute the probabilities for the three routing options.
9531032 # This is computed according to https://arxiv.org/abs/2407.15595.
954- p_clean = raw_p_up / scale_factor
955- p_noise = raw_p_down / scale_factor
1033+ p_clean = p_clean_raw / scale_factor
1034+ p_noise = p_noise_raw / scale_factor
9561035 p_stay = 1.0 - p_clean - p_noise
9571036
9581037 routing_weights = RoutingWeights (stay = p_stay , noise = p_noise , clean = p_clean )
0 commit comments