diff --git a/hackable_diffusion/lib/sampling/__init__.py b/hackable_diffusion/lib/sampling/__init__.py
index 1b38266..8c84bd9 100644
--- a/hackable_diffusion/lib/sampling/__init__.py
+++ b/hackable_diffusion/lib/sampling/__init__.py
@@ -30,6 +30,7 @@
 from hackable_diffusion.lib.sampling.discrete_step_sampler import NoRemaskingFn
 from hackable_diffusion.lib.sampling.discrete_step_sampler import RemaskingFn
 from hackable_diffusion.lib.sampling.discrete_step_sampler import RescaledRemaskingFn
+from hackable_diffusion.lib.sampling.discrete_step_sampler import truncate_at_stop_tokens
 from hackable_diffusion.lib.sampling.discrete_step_sampler import UnMaskingStep
 from hackable_diffusion.lib.sampling.gaussian_step_sampler import AdjustedDDIMStep
 from hackable_diffusion.lib.sampling.gaussian_step_sampler import DDIMStep
diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py
index 4e1f3c6..8b13184 100644
--- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py
+++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py
@@ -60,6 +60,87 @@
 CategoricalProcess = discrete.CategoricalProcess
 DiscreteSchedule = schedules.DiscreteSchedule
 
+
+################################################################################
+# MARK: Stop token truncation
+################################################################################
+
+
+def truncate_at_stop_tokens(
+    tokens: DataArray,
+    stop_tokens: tuple[int, ...],
+    pad_token: int = 0,
+) -> DataArray:
+  """Replace tokens after the first stop token with pad_token.
+
+  This is a utility for discrete diffusion samplers that support stop-token
+  based truncation (e.g. for text generation). The function identifies the
+  first occurrence of any token in `stop_tokens` along the sequence dimension
+  and pads everything after it.
+
+  The function supports arbitrary leading batch dimensions. The sequence
+  dimension is assumed to be the second-to-last axis (``tokens.shape[-2]``),
+  and the last axis is the feature dimension (typically size 1).
+
+  Args:
+    tokens: Token array of shape ``(*batch, seq_len, features)``.
+    stop_tokens: Tuple of token IDs that signal end of generation.
+    pad_token: Token ID to use for padding after the first stop token.
+
+  Returns:
+    Token array with the same shape, padded after the first stop token.
+  """
+  stop_arr = jnp.array(stop_tokens, dtype=jnp.int32)
+
+  # Identify stop token positions.
+  is_stop = jnp.isin(tokens[..., 0], stop_arr)
+
+  # keepdims=True makes broadcasting work for any number of batch dims.
+  has_stop = jnp.any(is_stop, axis=-1, keepdims=True)
+  first_stop_idx = jnp.argmax(is_stop.astype(jnp.int32), axis=-1, keepdims=True)
+
+  # Create sequence indices: shape (seq_len,)
+  seq_len = tokens.shape[-2]
+  seq_idx = jnp.arange(seq_len)
+
+  # Broadcast: (..., 1) vs (seq_len,) handles batch dims automatically.
+  stop_limit = jnp.where(has_stop, first_stop_idx, seq_len)
+  keep_mask = seq_idx <= stop_limit
+
+  # Apply mask to the last dimension (features) using [..., None].
+  return jnp.where(keep_mask[..., None], tokens, pad_token)
+
+
+def _finalize_with_stop_tokens(
+    sampler: SamplerStep,
+    prediction: TargetInfo,
+    current_step: DiffusionStep,
+    last_step_info: StepInfo,
+    stop_tokens: tuple[int, ...] | None,
+    pad_token: int,
+) -> DiffusionStep:
+  """Common finalize logic: run update then optionally truncate at stop tokens.
+
+  Args:
+    sampler: The sampler step whose ``update`` method is called.
+    prediction: The prediction from the inference function.
+    current_step: The current diffusion step.
+    last_step_info: The step info for the final step.
+    stop_tokens: If not None, truncate at the first occurrence of any of these
+      token IDs.
+    pad_token: Token ID used for padding after the first stop token.
+
+  Returns:
+    The finalized diffusion step, with tokens truncated at the first stop
+    token if ``stop_tokens`` is not None.
+  """
+  result = sampler.update(prediction, current_step, last_step_info)
+  if stop_tokens is not None:
+    new_xt = truncate_at_stop_tokens(result.xt, stop_tokens, pad_token)
+    result = result.replace(xt=new_xt)
+  return result
+
+
 ################################################################################
 # MARK: Remasking strategy
 ################################################################################
@@ -250,6 +331,8 @@ class UnMaskingStep(SamplerStep):
   corruption_mask_fn: CorruptedMaskFn = AllCorruptedMaskFn()
   temperature: float = 1.0
   logits_dtype: jnp.dtype = jnp.float32
+  stop_tokens: tuple[int, ...] | None = None
+  pad_token: int = 0
 
   def __post_init__(self):
     """UnMaskingStep only supports masking processes.
@@ -369,10 +452,13 @@ def finalize(
       self,
       prediction: TargetInfo,
       current_step: DiffusionStep,
       last_step_info: StepInfo,
   ) -> DiffusionStep:
-    return self.update(
+    return _finalize_with_stop_tokens(
+        self,
         prediction,
         current_step,
         last_step_info,
+        self.stop_tokens,
+        self.pad_token,
     )
@@ -406,6 +492,8 @@ class DiscreteDDIMStep(SamplerStep):
   corruption_process: CategoricalProcess
   temperature: float = 1.0
   logits_dtype: jnp.dtype = jnp.float32
+  stop_tokens: tuple[int, ...] | None = None
+  pad_token: int = 0
 
   def __post_init__(self):
     """DiscreteDDIMStep does not support masking processes.
@@ -533,10 +621,13 @@ def finalize(
       self,
       prediction: TargetInfo,
       current_step: DiffusionStep,
       last_step_info: StepInfo,
   ) -> DiffusionStep:
-    return self.update(
+    return _finalize_with_stop_tokens(
+        self,
         prediction,
         current_step,
         last_step_info,
+        self.stop_tokens,
+        self.pad_token,
     )
@@ -584,6 +675,8 @@ class IntegratedDiscreteDDIMStep(SamplerStep):
   corruption_process: CategoricalProcess
   temperature: float = 1.0
   logits_dtype: jnp.dtype = jnp.float32
+  stop_tokens: tuple[int, ...] | None = None
+  pad_token: int = 0
 
   def __post_init__(self):
     """IntegratedDiscreteDDIMStep does not support masking processes.
@@ -711,10 +804,13 @@ def finalize(
       self,
       prediction: TargetInfo,
       current_step: DiffusionStep,
       last_step_info: StepInfo,
   ) -> DiffusionStep:
-    return self.update(
+    return _finalize_with_stop_tokens(
+        self,
         prediction,
         current_step,
         last_step_info,
+        self.stop_tokens,
+        self.pad_token,
     )
@@ -750,6 +846,8 @@ class DiscreteFlowMatchingStep(SamplerStep):
   temperature: float = 1.0
   gamma: float = 0.0
   logits_dtype: jnp.dtype = jnp.float32
+  stop_tokens: tuple[int, ...] | None = None
+  pad_token: int = 0
 
   @kt.typechecked
   def initialize(
@@ -869,8 +967,11 @@ def finalize(
       current_step: DiffusionStep,
       last_step_info: StepInfo,
   ) -> DiffusionStep:
-    return self.update(
+    return _finalize_with_stop_tokens(
+        self,
         prediction,
         current_step,
         last_step_info,
+        self.stop_tokens,
+        self.pad_token,
     )