diff --git a/hackable_diffusion/lib/sampling/__init__.py b/hackable_diffusion/lib/sampling/__init__.py
index 1b38266..eeca37a 100644
--- a/hackable_diffusion/lib/sampling/__init__.py
+++ b/hackable_diffusion/lib/sampling/__init__.py
@@ -20,6 +20,7 @@
 from hackable_diffusion.lib.sampling.base import SamplerStep
 from hackable_diffusion.lib.sampling.base import StepInfo
 from hackable_diffusion.lib.sampling.base import StepInfoTree
+from hackable_diffusion.lib.sampling.base import UpdateConditioningFn
 from hackable_diffusion.lib.sampling.discrete_step_sampler import AllCorruptedMaskFn
 from hackable_diffusion.lib.sampling.discrete_step_sampler import CorruptedMaskFn
 from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteDDIMStep
diff --git a/hackable_diffusion/lib/sampling/base.py b/hackable_diffusion/lib/sampling/base.py
index 415bed4..4f928a0 100644
--- a/hackable_diffusion/lib/sampling/base.py
+++ b/hackable_diffusion/lib/sampling/base.py
@@ -79,6 +79,7 @@
 
 DataArray = hd_typing.DataArray
 DataTree = hd_typing.DataTree
+Conditioning = hd_typing.Conditioning
 TargetInfoTree = hd_typing.TargetInfoTree
 TimeArray = hd_typing.TimeArray
 
@@ -128,6 +129,7 @@ class DiffusionStep:
 
 DiffusionStepTree = PyTree[DiffusionStep]
 
+
 ################################################################################
 # MARK: Protocols
 ################################################################################
@@ -161,3 +163,28 @@ def finalize(
   ) -> DiffusionStepTree:
     """Performs the final step to produce the clean output sample."""
     ...
+
+
+class UpdateConditioningFn(Protocol):
+  """Protocol for updating conditioning during the sampling loop.
+
+  This allows injecting step-dependent information back into the conditioning
+  dict between sampling steps (e.g. self-conditioning logits from the
+  previous prediction).
+  """
+
+  def __call__(
+      self,
+      conditioning: Conditioning,
+      step_carry: DiffusionStepTree,
+  ) -> Conditioning:
+    """Update conditioning based on the current diffusion step.
+
+    Args:
+      conditioning: The current conditioning dict.
+      step_carry: The current diffusion step state.
+
+    Returns:
+      The updated conditioning dict.
+    """
+    ...
diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py
index 4e1f3c6..8b13184 100644
--- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py
+++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py
@@ -60,6 +60,87 @@
 CategoricalProcess = discrete.CategoricalProcess
 DiscreteSchedule = schedules.DiscreteSchedule
 
+
+################################################################################
+# MARK: Stop token truncation
+################################################################################
+
+
+def truncate_at_stop_tokens(
+    tokens: DataArray,
+    stop_tokens: tuple[int, ...],
+    pad_token: int = 0,
+) -> DataArray:
+  """Replace tokens after the first stop token with pad_token.
+
+  This is a utility for discrete diffusion samplers that support stop-token
+  based truncation (e.g. for text generation). The function identifies the
+  first occurrence of any token in `stop_tokens` along the sequence dimension
+  and pads everything after it.
+
+  The function supports arbitrary leading batch dimensions. The sequence
+  dimension is assumed to be the second-to-last axis (``tokens.shape[-2]``),
+  and the last axis is the feature dimension (typically size 1).
+
+  Args:
+    tokens: Token array of shape ``(*batch, seq_len, features)``.
+    stop_tokens: Tuple of token IDs that signal end of generation.
+    pad_token: Token ID to use for padding after the first stop token.
+
+  Returns:
+    Token array with the same shape, padded after the first stop token.
+  """
+  stop_arr = jnp.array(stop_tokens, dtype=jnp.int32)
+
+  # Identify stop token positions.
+  is_stop = jnp.isin(tokens[..., 0], stop_arr)
+
+  # keepdims=True makes broadcasting work for any number of batch dims.
+  has_stop = jnp.any(is_stop, axis=-1, keepdims=True)
+  first_stop_idx = jnp.argmax(is_stop.astype(jnp.int32), axis=-1, keepdims=True)
+
+  # Create sequence indices: shape (seq_len,)
+  seq_len = tokens.shape[-2]
+  seq_idx = jnp.arange(seq_len)
+
+  # Broadcast: (..., 1) vs (seq_len,) handles batch dims automatically.
+  stop_limit = jnp.where(has_stop, first_stop_idx, seq_len)
+  keep_mask = seq_idx <= stop_limit
+
+  # Apply mask to the last dimension (features) using [..., None].
+  return jnp.where(keep_mask[..., None], tokens, pad_token)
+
+
+def _finalize_with_stop_tokens(
+    sampler: SamplerStep,
+    prediction: TargetInfo,
+    current_step: DiffusionStep,
+    last_step_info: StepInfo,
+    stop_tokens: tuple[int, ...] | None,
+    pad_token: int,
+) -> DiffusionStep:
+  """Common finalize logic: run update then optionally truncate at stop tokens.
+
+  Args:
+    sampler: The sampler step whose ``update`` method is called.
+    prediction: The prediction from the inference function.
+    current_step: The current diffusion step.
+    last_step_info: The step info for the final step.
+    stop_tokens: If not None, truncate at the first occurrence of any of these
+      token IDs.
+    pad_token: Token ID used for padding after the first stop token.
+
+  Returns:
+    The finalized diffusion step, with tokens truncated at the first stop
+    token if ``stop_tokens`` is not None.
+  """
+  result = sampler.update(prediction, current_step, last_step_info)
+  if stop_tokens is not None:
+    new_xt = truncate_at_stop_tokens(result.xt, stop_tokens, pad_token)
+    result = result.replace(xt=new_xt)
+  return result
+
+
 ################################################################################
 # MARK: Remasking strategy
 ################################################################################
@@ -250,6 +331,8 @@ class UnMaskingStep(SamplerStep):
   corruption_mask_fn: CorruptedMaskFn = AllCorruptedMaskFn()
   temperature: float = 1.0
   logits_dtype: jnp.dtype = jnp.float32
+  stop_tokens: tuple[int, ...] | None = None
+  pad_token: int = 0
 
   def __post_init__(self):
     """UnMaskingStep only supports masking processes.
@@ -369,10 +452,13 @@ def finalize(
       current_step: DiffusionStep,
       last_step_info: StepInfo,
   ) -> DiffusionStep:
-    return self.update(
+    return _finalize_with_stop_tokens(
+        self,
         prediction,
         current_step,
         last_step_info,
+        self.stop_tokens,
+        self.pad_token,
     )
 
 
@@ -406,6 +492,8 @@ class DiscreteDDIMStep(SamplerStep):
   corruption_process: CategoricalProcess
   temperature: float = 1.0
   logits_dtype: jnp.dtype = jnp.float32
+  stop_tokens: tuple[int, ...] | None = None
+  pad_token: int = 0
 
   def __post_init__(self):
     """DiscreteDDIMStep does not support masking processes.
@@ -533,10 +621,13 @@ def finalize(
       current_step: DiffusionStep,
       last_step_info: StepInfo,
   ) -> DiffusionStep:
-    return self.update(
+    return _finalize_with_stop_tokens(
+        self,
         prediction,
         current_step,
         last_step_info,
+        self.stop_tokens,
+        self.pad_token,
     )
 
 
@@ -584,6 +675,8 @@ class IntegratedDiscreteDDIMStep(SamplerStep):
   corruption_process: CategoricalProcess
   temperature: float = 1.0
   logits_dtype: jnp.dtype = jnp.float32
+  stop_tokens: tuple[int, ...] | None = None
+  pad_token: int = 0
 
   def __post_init__(self):
     """IntegratedDiscreteDDIMStep does not support masking processes.
@@ -711,10 +804,13 @@ def finalize(
       current_step: DiffusionStep,
       last_step_info: StepInfo,
   ) -> DiffusionStep:
-    return self.update(
+    return _finalize_with_stop_tokens(
+        self,
         prediction,
         current_step,
         last_step_info,
+        self.stop_tokens,
+        self.pad_token,
     )
 
 
@@ -750,6 +846,8 @@ class DiscreteFlowMatchingStep(SamplerStep):
   temperature: float = 1.0
   gamma: float = 0.0
   logits_dtype: jnp.dtype = jnp.float32
+  stop_tokens: tuple[int, ...] | None = None
+  pad_token: int = 0
 
   @kt.typechecked
   def initialize(
@@ -869,8 +967,11 @@ def finalize(
       current_step: DiffusionStep,
       last_step_info: StepInfo,
   ) -> DiffusionStep:
-    return self.update(
+    return _finalize_with_stop_tokens(
+        self,
         prediction,
         current_step,
         last_step_info,
+        self.stop_tokens,
+        self.pad_token,
     )
diff --git a/hackable_diffusion/lib/sampling/sampling.py b/hackable_diffusion/lib/sampling/sampling.py
index de46bc5..0c6a08b 100644
--- a/hackable_diffusion/lib/sampling/sampling.py
+++ b/hackable_diffusion/lib/sampling/sampling.py
@@ -47,6 +47,7 @@
 InferenceFn = inference_base.InferenceFn
 TimeSchedule = time_scheduling.TimeSchedule
 
+UpdateConditioningFn = base.UpdateConditioningFn
 ################################################################################
 # MARK: Protocols
 ################################################################################
@@ -130,11 +131,14 @@ class DiffusionSampler(SampleFn):
     time_schedule: Defines the sequence of time steps for the process.
     stepper: The sampling algorithm (e.g., DDIM) that updates the state.
     num_steps: The total number of denoising steps.
+    update_conditioning_fn: An optional function to update the conditioning at
+      each step.
   """
 
   time_schedule: TimeSchedule
   stepper: SamplerStep
   num_steps: int
+  update_conditioning_fn: UpdateConditioningFn | None = None
 
   @kt.typechecked
   def __call__(
@@ -181,9 +185,14 @@ def __call__(
 
     def scan_body(step_carry: DiffusionStepTree, next_step_info: StepInfoTree):
       xt, time = _get_input_inference_fn(step_carry)
+      updated_conditioning = conditioning
+      if self.update_conditioning_fn is not None:
+        updated_conditioning = self.update_conditioning_fn(
+            conditioning, step_carry
+        )
       prediction = inference_fn(
           xt=xt,
-          conditioning=conditioning,
+          conditioning=updated_conditioning,
           time=time,
       )
       next_step = self.stepper.update(
@@ -198,9 +207,14 @@ def scan_body(step_carry: DiffusionStepTree, next_step_info: StepInfoTree):
     )
 
     xt, time = _get_input_inference_fn(before_last_step)
+    last_conditioning = conditioning
+    if self.update_conditioning_fn is not None:
+      last_conditioning = self.update_conditioning_fn(
+          conditioning, before_last_step
+      )
     last_prediction = inference_fn(
         xt=xt,
-        conditioning=conditioning,
+        conditioning=last_conditioning,
         time=time,
     )
 