1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- """Actual implementation of the sampling steps.
16-
17- This module proposes various implementations but they all have in common
18- the core logic:
19-
20- * An `initialize` function that takes a starting state and returns the
21- first step of the diffusion process.
22- * An `update` function that takes the current state and returns the next step.
23- * A `finalize` function that takes the last state and returns the final
24- state.
25-
26- At every step, the update function takes the current state and returns the next
27- state. The update is also in charge of computing other auxiliary informations
28- such as volatility, drifts, etc.
29-
30- The `InferenceFn` is also called within the step and converted into the
31- relevant representation, for instance score, velocity, etc.
15+ """Actual implementation of the sampling steps for discrete diffusion.
16+
17+ This module implements various discrete samplers (UnMasking, DDIM, Flow
18+ Matching) using a unified **routing** architecture.
19+
20+ Core concepts:
21+ * **Routing**: Samplers decompose the transition probability into a mixture of
22+ three actions:
23+ - STAY (0): Keep the current token `xt`.
24+ - NOISE (1): Sample from the invariant distribution.
25+ - CLEAN (2): Use the predicted clean token `x0`.
26+ * **RoutingWeightPlanner**: An optional transformation applied to routing
27+ weights before they are sampled. This allows injecting custom selection
28+ strategies (e.g., greedy top-k) without modifying the sampler physics.
29+
30+ Standard interface for all samplers:
31+ * `initialize`: Takes a starting state and returns the first step.
32+ * `update`: Takes current state and returns the next step.
33+ * `finalize`: Takes last state and returns the final state.
3234
3335This module also introduces the concepts of Routing and Planning for discrete
3436diffusion:
@@ -386,8 +388,8 @@ class GreedyPlanner(RoutingStrategy):
386388 """Confidence-based top-k greedy planner.
387389
388390 Selects the top-k positions by confidence of the sampled x0 token, forces
389- those to CLEAN, and lets the remaining positions follow their natural
390- stay/noise dynamics (with p_clean zeroed out) .
391+ those to CLEAN, and handles the remaining positions according to the
392+ ``resample`` flag.
391393
392394 The budget k is computed as:
393395 k = num_eligible * p_clean_norm
@@ -398,9 +400,14 @@ class GreedyPlanner(RoutingStrategy):
398400 Attributes:
399401 tie_breaking_noise: Scale of uniform noise added to confidence scores for
400402 tie-breaking. Default 1e-6.
403+ resample: If True, non-selected positions keep their full original routing
404+ weights (stay, noise, and clean), allowing stochastic resampling via the
405+ posterior. If False (default), non-selected positions have their clean
406+ weight zeroed out so they can only stay or re-noise.
401407 """
402408
403409 tie_breaking_noise : float = 1e-6
410+ resample : bool = False
404411
405412 @kt .typechecked
406413 def __call__ (
@@ -413,12 +420,21 @@ def __call__(
413420 next_time : TimeArray ,
414421 key : PRNGKey ,
415422 ) -> RoutingWeights :
423+ # Routing weights have shape (batch, *spatial, 1). We flatten all spatial
424+ # dims into a single "positions" dim so top-k selection works regardless
425+ # of whether the data is 1D (sequences) or 2D (adjacency matrices).
426+ batch_size = routing_weights .stay .shape [0 ]
427+ spatial_shape = routing_weights .stay .shape [1 :- 1 ] # e.g. (seq,) or (N, N)
428+
416429 # Confidence = softmax(logits)[x0] per position
417430 p = jax .nn .softmax (logits , axis = - 1 )
418431 confidence = jnp .take_along_axis (p , x0 , axis = - 1 ).squeeze (- 1 )
432+ # confidence: (batch, *spatial) → flatten to (batch, num_positions)
433+ confidence = confidence .reshape (batch_size , - 1 )
419434
420435 # Only consider positions that could go clean (p_clean > 0)
421436 eligible = routing_weights .clean [..., 0 ] > 0
437+ eligible = eligible .reshape (batch_size , - 1 )
422438 confidence = jnp .where (eligible , confidence , - jnp .inf )
423439
424440 # Add tie-breaking noise
@@ -437,24 +453,27 @@ def __call__(
437453 p_clean_norm = routing_weights .clean / jnp .maximum (total_weight , 1e-12 )
438454 # p_clean_norm is the same for all eligible positions. Take max over spatial
439455 # dimensions to get the value (ineligible positions have 0).
440- p_clean_norm_flat = p_clean_norm .reshape (p_clean_norm . shape [ 0 ] , - 1 )
456+ p_clean_norm_flat = p_clean_norm .reshape (batch_size , - 1 )
441457 frac = jnp .max (p_clean_norm_flat , axis = - 1 , keepdims = True )
442458 frac = jnp .clip (frac , 0.0 , 1.0 )
443459
444460 num_eligible = jnp .sum (eligible .astype (jnp .float32 ), axis = - 1 , keepdims = True )
445461 k = (num_eligible * frac ).astype (jnp .int32 )
446462
447- # Top-k threshold
448- seq_len = confidence .shape [- 1 ]
463+ # Top-k threshold (operates on flattened positions)
464+ num_positions = confidence .shape [- 1 ]
449465 sorted_conf = jnp .sort (confidence , axis = - 1 )[..., ::- 1 ]
450466 threshold = jnp .take_along_axis (
451- sorted_conf , jnp .clip (k - 1 , 0 , seq_len - 1 ), axis = - 1
467+ sorted_conf , jnp .clip (k - 1 , 0 , num_positions - 1 ), axis = - 1
452468 )
453469 # When k=0, no positions should be selected.
454470 to_update = (confidence >= threshold ) & (k > 0 )
455471
456- # Selected positions → force CLEAN (zero out stay/noise).
457- # Non-selected positions → zero out p_clean, keep original stay/noise.
472+ # Unflatten to_update back to original spatial shape
473+ to_update = to_update .reshape (batch_size , * spatial_shape )
474+
475+ # Non-selected positions have clean zeroed out; they can only stay or
476+ # re-noise according to the posterior.
458477 return RoutingWeights (
459478 stay = jnp .where (to_update [..., None ], 0.0 , routing_weights .stay ),
460479 noise = jnp .where (to_update [..., None ], 0.0 , routing_weights .noise ),
@@ -489,6 +508,8 @@ class UnMaskingStep(SamplerStep):
489508
490509 Attributes:
491510 corruption_process: The corruption process to use.
511+ planner: The planner to use for transforming routing weights. This is
512+ optional with the default being ``None`` (pure stochastic sampling).
492513 remasking_fn: The remasking function to use, see
493514 https://arxiv.org/abs/2503.00307v1. This is optional with the default
494515 being no remasking.
@@ -713,6 +734,12 @@ class DiscreteDDIMStep(SamplerStep):
713734 Note: when π = δ_MASK (masking process) and x_t = MASK, this reduces to:
714735 P(unmask) = (α_s - α_t) / (1 - α_t),
715736 which coincides with the UnMaskingStep formula (without remasking).
737+
738+ Attributes:
739+ corruption_process: The corruption process to use.
740+ planner: The planner to use for transforming routing weights. This is
741+ optional with the default being ``None`` (pure stochastic sampling).
742+ temperature: The temperature to use.
716743 """
717744
718745 corruption_process : CategoricalProcess
@@ -863,15 +890,17 @@ class DiscreteFlowMatchingStep(SamplerStep):
863890 This sampler uses the 3-way routing representation. The update rule
864891 decomposes naturally into:
865892
866- p(x_s) = p_stay * δ_{x_t} + p_up * p_x0 + p_down * π
893+ p(x_s) = p_stay * δ_{x_t} + p_clean * p_x0 + p_noise * π
867894
868895 where:
869- - p_stay = 1 - p_up - p_down
870- - p_up = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
871- - p_down = (α_s - α_t) / α_t * stoch_coeff
896+ - p_stay = 1 - p_clean - p_noise
897+ - p_clean = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff)
898+ - p_noise = (α_s - α_t) / α_t * stoch_coeff
872899
873900 Attributes:
874901 corruption_process: The corruption process to use.
902+ planner: The planner to use for transforming routing weights. This is
903+ optional with the default being ``None`` (pure stochastic sampling).
875904 temperature: The temperature to use.
876905 stoch_coeff: The stochasticity coefficient (default 0.0). Higher values
877906 introduce more noise during the denoising process.
@@ -934,25 +963,25 @@ def update(
934963 alpha_s = self .corruption_process .schedule .alpha (next_time_bcast )
935964 alpha_t = self .corruption_process .schedule .alpha (time_bcast )
936965
937- prob_up = (
966+ p_clean_raw = (
938967 (alpha_s - alpha_t )
939968 / jnp .maximum (1.0 - alpha_t , 1e-12 )
940969 * (1.0 + self .stoch_coeff )
941970 )
942- prob_down = (
971+ p_noise_raw = (
943972 (alpha_s - alpha_t ) / jnp .maximum (alpha_t , 1e-12 ) * self .stoch_coeff
944973 )
945974
946- # Clip and rescale to ensure valid probabilities
947- raw_p_up = jnp .maximum (prob_up , 0.0 )
948- raw_p_down = jnp .maximum (prob_down , 0.0 )
949- sum_jumps = raw_p_up + raw_p_down
975+ # Clip and rescale to ensure valid weights
976+ p_clean_raw = jnp .maximum (p_clean_raw , 0.0 )
977+ p_noise_raw = jnp .maximum (p_noise_raw , 0.0 )
978+ sum_jumps = p_clean_raw + p_noise_raw
950979 scale_factor = jnp .maximum (1.0 , sum_jumps )
951980
952981 # Compute the probabilities for the three routing options.
953982 # This is computed according to https://arxiv.org/abs/2407.15595.
954- p_clean = raw_p_up / scale_factor
955- p_noise = raw_p_down / scale_factor
983+ p_clean = p_clean_raw / scale_factor
984+ p_noise = p_noise_raw / scale_factor
956985 p_stay = 1.0 - p_clean - p_noise
957986
958987 routing_weights = RoutingWeights (stay = p_stay , noise = p_noise , clean = p_clean )
0 commit comments