|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -"""Actual implementation of the sampling steps. |
16 | | -
|
17 | | -This module proposes various implementations but they all have in common |
18 | | -the core logic: |
19 | | -
|
20 | | -* An `initialize` function that takes a starting state and returns the |
21 | | - first step of the diffusion process. |
22 | | -* An `update` function that takes the current state and returns the next step. |
23 | | -* A `finalize` function that takes the last state and returns the final |
24 | | - state. |
25 | | -
|
26 | | -At every step, the update function takes the current state and returns the next |
27 | | -state. The update is also in charge of computing other auxiliary informations |
28 | | -such as volatility, drifts, etc. |
29 | | -
|
30 | | -The `InferenceFn` is also called within the step and converted into the |
31 | | -relevant representation, for instance score, velocity, etc. |
| 15 | +"""Actual implementation of the sampling steps for discrete diffusion. |
| 16 | +
|
| 17 | +This module implements various discrete samplers (UnMasking, DDIM, Flow |
| 18 | +Matching) using a unified **routing** architecture. |
| 19 | +
|
| 20 | +Core concepts: |
| 21 | +* **Routing**: Samplers decompose the transition probability into a mixture of |
| 22 | + three actions: |
| 23 | + - STAY (0): Keep the current token `xt`. |
| 24 | + - NOISE (1): Sample from the invariant distribution. |
| 25 | + - CLEAN (2): Use the predicted clean token `x0`. |
| 26 | +* **RoutingWeightPlanner**: An optional transformation applied to routing |
| 27 | +  weights before the routing action is sampled, allowing custom selection |
| 28 | +  strategies (e.g., greedy top-k) without modifying the sampler physics. |
| 29 | +
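| | +  Concretely, the three actions combine into the transition mixture (in the |
| | +  notation of the samplers below): |
| | +
| | +      p(x_s) = p_stay * δ_{x_t} + p_clean * p_x0 + p_noise * π |
| | +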
|
| 30 | +Standard interface for all samplers: |
| 31 | +* `initialize`: Takes a starting state and returns the first step. |
| 32 | +* `update`: Takes the current state and returns the next step. |
| 33 | +* `finalize`: Takes the last state and returns the final state. |
| 34 | +
|
| 35 | +The following diagram illustrates the flow of the discrete denoising process: |
| 36 | +
|
| 37 | + ┌───────────────────────┐ |
| 38 | + │ DiffusionStep │ |
| 39 | + │ Holds current state │ |
| 40 | + │ xt │ |
| 41 | + └───────────┬───────────┘ |
| 42 | + │ |
| 43 | + ▼ |
| 44 | + ┌───────────────────────┐ |
| 45 | + │ InferenceFn │ |
| 46 | + │ Calls neural network │ |
| 47 | + │ to get predictions │ |
| 48 | + └───────────┬───────────┘ |
| 49 | + │ prediction (logits) |
| 50 | + ▼ |
| 51 | + ┌───────────────────────┐ |
| 52 | + │ _generate_candidates │ |
| 53 | + │ Extracts x0, x_noise │ |
| 54 | + │ │ |
| 55 | + └───────────┬───────────┘ |
| 56 | + │ |
| 57 | + ▼ |
| 58 | + ┌───────────────────────┐ |
| 59 | + │ Compute Routing Wgts │ |
| 60 | + │ Decomposes posterior │ |
| 61 | + │ to STAY/NOISE/CLEAN │ |
| 62 | + └───────────┬───────────┘ |
| 63 | + │ |
| 64 | + ▼ |
| 65 | + ┌───────────────────────┐ |
| 66 | + │ Planner │ |
| 67 | + │ Modifies weights │ |
| 68 | + │ (Optional, e.g. topk)│ |
| 69 | + └───────────┬───────────┘ |
| 70 | + │ |
| 71 | + ▼ |
| 72 | + ┌───────────────────────┐ |
| 73 | + │ _sample_routing │ |
| 74 | + │ Categorical sampling │ |
| 75 | + │ to get next xt │ |
| 76 | + └───────────┬───────────┘ |
| 77 | + │ |
| 78 | + ▼ |
| 79 | + ┌───────────────────────┐ |
| 80 | + │ DiffusionStep │ |
| 81 | + │ Holds next state │ |
| 82 | + │ xt │ |
| 83 | + └───────────────────────┘ |
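| | +
| | +  As a minimal sketch, the interface composes into a loop as follows |
| | +  (``sampler``, ``state``, and ``num_steps`` are hypothetical names used |
| | +  only for illustration; the actual methods take additional arguments such |
| | +  as times and PRNG keys): |
| | +
| | +      step = sampler.initialize(state) |
| | +      for _ in range(num_steps): |
| | +        step = sampler.update(step) |
| | +      final_state = sampler.finalize(step) |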
32 | 84 |
|
33 | 85 | This module also introduces the concepts of Routing and Planning for discrete |
34 | 86 | diffusion: |
@@ -381,6 +433,104 @@ def __call__( |
381 | 433 | ... |
382 | 434 |
|
383 | 435 |
|
| 436 | +@dataclasses.dataclass(frozen=True, kw_only=True) |
| 437 | +class GreedyPlanner(RoutingStrategy): |
| 438 | + """Confidence-based top-k greedy planner. |
| 439 | +
|
| 440 | + Selects the top-k positions by confidence of the sampled x0 token, forces |
| 441 | + those to CLEAN, and handles the remaining positions according to the |
| 442 | + ``resample`` flag. |
| 443 | +
|
| 444 | +  The budget k is computed as: |
| 445 | +      k = num_eligible * p_clean_norm |
| 446 | +  where num_eligible is the number of positions with p_clean > 0 and |
| 447 | +  p_clean_norm = p_clean / (p_stay + p_noise + p_clean). For a linear |
| 448 | +  schedule α = 1−t this simplifies to k = num_eligible * (1 − next_time/time). |
| 449 | +
|
| 450 | + Attributes: |
| 451 | + tie_breaking_noise: Scale of uniform noise added to confidence scores for |
| 452 | + tie-breaking. Default 1e-6. |
| 453 | + resample: If True, non-selected positions keep their full original routing |
| 454 | + weights (stay, noise, and clean), allowing stochastic resampling via the |
| 455 | + posterior. If False (default), non-selected positions have their clean |
| 456 | + weight zeroed out so they can only stay or re-noise. |
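| | +
| | +  Example (a sketch assuming a linear schedule α = 1 − t): with 10 eligible |
| | +  masked positions, time = 0.5 and next_time = 0.25, the budget is |
| | +  k = 10 * (1 − 0.25 / 0.5) = 5, so the five most confident positions are |
| | +  forced to CLEAN and the rest stay or re-noise. |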
| 457 | + """ |
| 458 | + |
| 459 | + tie_breaking_noise: float = 1e-6 |
| 460 | + resample: bool = False |
| 461 | + |
| 462 | + @kt.typechecked |
| 463 | + def __call__( |
| 464 | + self, |
| 465 | + routing_weights: RoutingWeights, |
| 466 | + logits: Float['... M'], |
| 467 | + x0: DataArray, |
| 468 | + xt: DataArray, |
| 469 | + time: TimeArray, |
| 470 | + next_time: TimeArray, |
| 471 | + key: PRNGKey, |
| 472 | + ) -> RoutingWeights: |
| 473 | + # Routing weights have shape (batch, *spatial, 1). We flatten all spatial |
| 474 | + # dims into a single "positions" dim so top-k selection works regardless |
| 475 | + # of whether the data is 1D (sequences) or 2D (adjacency matrices). |
| 476 | + batch_size = routing_weights.stay.shape[0] |
| 477 | + spatial_shape = routing_weights.stay.shape[1:-1] # e.g. (seq,) or (N, N) |
| 478 | + |
| 479 | + # Confidence = softmax(logits)[x0] per position |
| 480 | + p = jax.nn.softmax(logits, axis=-1) |
| 481 | + confidence = jnp.take_along_axis(p, x0, axis=-1).squeeze(-1) |
| 482 | + # confidence: (batch, *spatial) → flatten to (batch, num_positions) |
| 483 | + confidence = confidence.reshape(batch_size, -1) |
| 484 | + |
| 485 | + # Only consider positions that could go clean (p_clean > 0) |
| 486 | + eligible = routing_weights.clean[..., 0] > 0 |
| 487 | + eligible = eligible.reshape(batch_size, -1) |
| 488 | + confidence = jnp.where(eligible, confidence, -jnp.inf) |
| 489 | + |
| 490 | + # Add tie-breaking noise |
| 491 | + _, subkey = jax.random.split(key) |
| 492 | + confidence += jax.random.uniform(subkey, confidence.shape) * ( |
| 493 | + self.tie_breaking_noise |
| 494 | + ) |
| 495 | + |
| 496 | + # Budget: k = num_eligible * p_clean_norm. |
| 497 | + # p_clean_norm = p_clean / (p_stay + p_noise + p_clean) is the same |
| 498 | + # for all eligible positions (π(x_t) cancels in normalization). |
| 499 | + # For a linear schedule (α = 1-t), this equals (1 - next_time/time). |
| 500 | + total_weight = ( |
| 501 | + routing_weights.stay + routing_weights.noise + routing_weights.clean |
| 502 | + ) |
| 503 | + p_clean_norm = routing_weights.clean / jnp.maximum(total_weight, 1e-12) |
| 504 | + # p_clean_norm is the same for all eligible positions. Take max over spatial |
| 505 | + # dimensions to get the value (ineligible positions have 0). |
| 506 | + p_clean_norm_flat = p_clean_norm.reshape(batch_size, -1) |
| 507 | + frac = jnp.max(p_clean_norm_flat, axis=-1, keepdims=True) |
| 508 | + frac = jnp.clip(frac, 0.0, 1.0) |
| 509 | + |
| 510 | + num_eligible = jnp.sum(eligible.astype(jnp.float32), axis=-1, keepdims=True) |
| 511 | + k = (num_eligible * frac).astype(jnp.int32) |
| 512 | + |
| 513 | + # Top-k threshold (operates on flattened positions) |
| 514 | + num_positions = confidence.shape[-1] |
| 515 | + sorted_conf = jnp.sort(confidence, axis=-1)[..., ::-1] |
| 516 | + threshold = jnp.take_along_axis( |
| 517 | + sorted_conf, jnp.clip(k - 1, 0, num_positions - 1), axis=-1 |
| 518 | + ) |
| 519 | + # When k=0, no positions should be selected. |
| 520 | + to_update = (confidence >= threshold) & (k > 0) |
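| | +    # Exact ties at the threshold could select slightly more than k |
| | +    # positions; the tie-breaking noise above makes ties vanishingly unlikely. |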
| 521 | + |
| 522 | + # Unflatten to_update back to original spatial shape |
| 523 | + to_update = to_update.reshape(batch_size, *spatial_shape) |
| 524 | + |
| 525 | +    # Force selected positions to CLEAN (see class docstring for `resample`). |
| 526 | +    unselected_clean = routing_weights.clean if self.resample else 0.0 |
| 527 | +    return RoutingWeights( |
| 528 | +        stay=jnp.where(to_update[..., None], 0.0, routing_weights.stay), |
| 529 | +        noise=jnp.where(to_update[..., None], 0.0, routing_weights.noise), |
| 530 | +        clean=jnp.where(to_update[..., None], 1.0, unselected_clean), |
| 531 | +    ) |
| 532 | + |
| 533 | + |
384 | 534 | ################################################################################ |
385 | 535 | # MARK: UnMasking Step |
386 | 536 | ################################################################################ |
@@ -408,6 +558,8 @@ class UnMaskingStep(SamplerStep): |
408 | 558 |
|
409 | 559 | Attributes: |
410 | 560 | corruption_process: The corruption process to use. |
| 561 | + planner: The planner to use for transforming routing weights. This is |
| 562 | + optional with the default being ``None`` (pure stochastic sampling). |
411 | 563 | remasking_fn: The remasking function to use, see |
412 | 564 | https://arxiv.org/abs/2503.00307v1. This is optional with the default |
413 | 565 | being no remasking. |
@@ -632,6 +784,12 @@ class DiscreteDDIMStep(SamplerStep): |
632 | 784 | Note: when π = δ_MASK (masking process) and x_t = MASK, this reduces to: |
633 | 785 | P(unmask) = (α_s - α_t) / (1 - α_t), |
634 | 786 | which coincides with the UnMaskingStep formula (without remasking). |
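| | +
| | +  As a quick numeric check (assuming a linear schedule α = 1 − t): stepping |
| | +  from t = 0.5 to s = 0.25 gives P(unmask) = (0.75 − 0.5) / (1 − 0.5) = 0.5. |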
| 787 | +
|
| 788 | + Attributes: |
| 789 | + corruption_process: The corruption process to use. |
| 790 | + planner: The planner to use for transforming routing weights. This is |
| 791 | + optional with the default being ``None`` (pure stochastic sampling). |
| 792 | + temperature: The temperature to use. |
635 | 793 | """ |
636 | 794 |
|
637 | 795 | corruption_process: CategoricalProcess |
@@ -782,15 +940,17 @@ class DiscreteFlowMatchingStep(SamplerStep): |
782 | 940 | This sampler uses the 3-way routing representation. The update rule |
783 | 941 | decomposes naturally into: |
784 | 942 |
|
785 | | - p(x_s) = p_stay * δ_{x_t} + p_up * p_x0 + p_down * π |
| 943 | + p(x_s) = p_stay * δ_{x_t} + p_clean * p_x0 + p_noise * π |
786 | 944 |
|
787 | 945 | where: |
788 | | - - p_stay = 1 - p_up - p_down |
789 | | - - p_up = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff) |
790 | | - - p_down = (α_s - α_t) / α_t * stoch_coeff |
| 946 | + - p_stay = 1 - p_clean - p_noise |
| 947 | + - p_clean = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff) |
| 948 | + - p_noise = (α_s - α_t) / α_t * stoch_coeff |
791 | 949 |
|
792 | 950 | Attributes: |
793 | 951 | corruption_process: The corruption process to use. |
| 952 | + planner: The planner to use for transforming routing weights. This is |
| 953 | + optional with the default being ``None`` (pure stochastic sampling). |
794 | 954 | temperature: The temperature to use. |
795 | 955 | stoch_coeff: The stochasticity coefficient (default 0.0). Higher values |
796 | 956 | introduce more noise during the denoising process. |
@@ -853,25 +1013,25 @@ def update( |
853 | 1013 | alpha_s = self.corruption_process.schedule.alpha(next_time_bcast) |
854 | 1014 | alpha_t = self.corruption_process.schedule.alpha(time_bcast) |
855 | 1015 |
|
856 | | - prob_up = ( |
| 1016 | + p_clean_raw = ( |
857 | 1017 | (alpha_s - alpha_t) |
858 | 1018 | / jnp.maximum(1.0 - alpha_t, 1e-12) |
859 | 1019 | * (1.0 + self.stoch_coeff) |
860 | 1020 | ) |
861 | | - prob_down = ( |
| 1021 | + p_noise_raw = ( |
862 | 1022 | (alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.stoch_coeff |
863 | 1023 | ) |
864 | 1024 |
|
865 | | - # Clip and rescale to ensure valid probabilities |
866 | | - raw_p_up = jnp.maximum(prob_up, 0.0) |
867 | | - raw_p_down = jnp.maximum(prob_down, 0.0) |
868 | | - sum_jumps = raw_p_up + raw_p_down |
| 1025 | + # Clip and rescale to ensure valid weights |
| 1026 | + p_clean_raw = jnp.maximum(p_clean_raw, 0.0) |
| 1027 | + p_noise_raw = jnp.maximum(p_noise_raw, 0.0) |
| 1028 | + sum_jumps = p_clean_raw + p_noise_raw |
869 | 1029 | scale_factor = jnp.maximum(1.0, sum_jumps) |
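| | +    # Dividing by scale_factor (>= 1) only rescales when the jump weights |
| | +    # would overshoot, guaranteeing p_stay = 1 - p_clean - p_noise >= 0. |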
870 | 1030 |
|
871 | 1031 | # Compute the probabilities for the three routing options. |
872 | 1032 | # This is computed according to https://arxiv.org/abs/2407.15595. |
873 | | - p_clean = raw_p_up / scale_factor |
874 | | - p_noise = raw_p_down / scale_factor |
| 1033 | + p_clean = p_clean_raw / scale_factor |
| 1034 | + p_noise = p_noise_raw / scale_factor |
875 | 1035 | p_stay = 1.0 - p_clean - p_noise |
876 | 1036 |
|
877 | 1037 | routing_weights = RoutingWeights(stay=p_stay, noise=p_noise, clean=p_clean) |
|