Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hackable_diffusion/lib/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from hackable_diffusion.lib.sampling.discrete_step_sampler import NoRemaskingFn
from hackable_diffusion.lib.sampling.discrete_step_sampler import RemaskingFn
from hackable_diffusion.lib.sampling.discrete_step_sampler import RescaledRemaskingFn
from hackable_diffusion.lib.sampling.discrete_step_sampler import truncate_at_stop_tokens
from hackable_diffusion.lib.sampling.discrete_step_sampler import UnMaskingStep
from hackable_diffusion.lib.sampling.gaussian_step_sampler import AdjustedDDIMStep
from hackable_diffusion.lib.sampling.gaussian_step_sampler import DDIMStep
Expand Down
109 changes: 105 additions & 4 deletions hackable_diffusion/lib/sampling/discrete_step_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,87 @@
CategoricalProcess = discrete.CategoricalProcess
DiscreteSchedule = schedules.DiscreteSchedule


################################################################################
# MARK: Stop token truncation
################################################################################


def truncate_at_stop_tokens(
tokens: DataArray,
stop_tokens: tuple[int, ...],
pad_token: int = 0,
) -> DataArray:
"""Replace tokens after the first stop token with pad_token.

This is a utility for discrete diffusion samplers that support stop-token
based truncation (e.g. for text generation). The function identifies the
first occurrence of any token in `stop_tokens` along the sequence dimension
and pads everything after it.

The function supports arbitrary leading batch dimensions. The sequence
dimension is assumed to be the second-to-last axis (``tokens.shape[-2]``),
and the last axis is the feature dimension (typically size 1).

Args:
tokens: Token array of shape ``(*batch, seq_len, features)``.
stop_tokens: Tuple of token IDs that signal end of generation.
pad_token: Token ID to use for padding after the first stop token.

Returns:
Token array with the same shape, padded after the first stop token.
"""
stop_arr = jnp.array(stop_tokens, dtype=jnp.int32)

# Identify stop token positions.
is_stop = jnp.isin(tokens[..., 0], stop_arr)

# keepdims=True makes broadcasting work for any number of batch dims.
has_stop = jnp.any(is_stop, axis=-1, keepdims=True)
first_stop_idx = jnp.argmax(is_stop.astype(jnp.int32), axis=-1, keepdims=True)

# Create sequence indices: shape (seq_len,)
seq_len = tokens.shape[-2]
seq_idx = jnp.arange(seq_len)

# Broadcast: (..., 1) vs (seq_len,) handles batch dims automatically.
stop_limit = jnp.where(has_stop, first_stop_idx, seq_len)
keep_mask = seq_idx <= stop_limit

# Apply mask to the last dimension (features) using [..., None].
return jnp.where(keep_mask[..., None], tokens, pad_token)


def _finalize_with_stop_tokens(
sampler: SamplerStep,
prediction: TargetInfo,
current_step: DiffusionStep,
last_step_info: StepInfo,
stop_tokens: tuple[int, ...] | None,
pad_token: int,
) -> DiffusionStep:
"""Common finalize logic: run update then optionally truncate at stop tokens.

Args:
sampler: The sampler step whose ``update`` method is called.
prediction: The prediction from the inference function.
current_step: The current diffusion step.
last_step_info: The step info for the final step.
stop_tokens: If not None, truncate at the first occurrence of any of these
token IDs.
pad_token: Token ID used for padding after the first stop token.

Returns:
The finalized diffusion step, with tokens truncated at the first stop
token if ``stop_tokens`` is not None.
"""
result = sampler.update(prediction, current_step, last_step_info)
if stop_tokens is not None:
new_xt = truncate_at_stop_tokens(result.xt, stop_tokens, pad_token)
result = result.replace(xt=new_xt)
return result


################################################################################
# MARK: Remasking strategy
################################################################################
Expand Down Expand Up @@ -250,6 +331,8 @@ class UnMaskingStep(SamplerStep):
corruption_mask_fn: CorruptedMaskFn = AllCorruptedMaskFn()
temperature: float = 1.0
logits_dtype: jnp.dtype = jnp.float32
stop_tokens: tuple[int, ...] | None = None
pad_token: int = 0

def __post_init__(self):
"""UnMaskingStep only supports masking processes.
Expand Down Expand Up @@ -369,10 +452,13 @@ def finalize(
current_step: DiffusionStep,
last_step_info: StepInfo,
) -> DiffusionStep:
return self.update(
return _finalize_with_stop_tokens(
self,
prediction,
current_step,
last_step_info,
self.stop_tokens,
self.pad_token,
)


Expand Down Expand Up @@ -406,6 +492,8 @@ class DiscreteDDIMStep(SamplerStep):
corruption_process: CategoricalProcess
temperature: float = 1.0
logits_dtype: jnp.dtype = jnp.float32
stop_tokens: tuple[int, ...] | None = None
pad_token: int = 0

def __post_init__(self):
"""DiscreteDDIMStep does not support masking processes.
Expand Down Expand Up @@ -533,10 +621,13 @@ def finalize(
current_step: DiffusionStep,
last_step_info: StepInfo,
) -> DiffusionStep:
return self.update(
return _finalize_with_stop_tokens(
self,
prediction,
current_step,
last_step_info,
self.stop_tokens,
self.pad_token,
)


Expand Down Expand Up @@ -584,6 +675,8 @@ class IntegratedDiscreteDDIMStep(SamplerStep):
corruption_process: CategoricalProcess
temperature: float = 1.0
logits_dtype: jnp.dtype = jnp.float32
stop_tokens: tuple[int, ...] | None = None
pad_token: int = 0

def __post_init__(self):
"""IntegratedDiscreteDDIMStep does not support masking processes.
Expand Down Expand Up @@ -711,10 +804,13 @@ def finalize(
current_step: DiffusionStep,
last_step_info: StepInfo,
) -> DiffusionStep:
return self.update(
return _finalize_with_stop_tokens(
self,
prediction,
current_step,
last_step_info,
self.stop_tokens,
self.pad_token,
)


Expand Down Expand Up @@ -750,6 +846,8 @@ class DiscreteFlowMatchingStep(SamplerStep):
temperature: float = 1.0
gamma: float = 0.0
logits_dtype: jnp.dtype = jnp.float32
stop_tokens: tuple[int, ...] | None = None
pad_token: int = 0

@kt.typechecked
def initialize(
Expand Down Expand Up @@ -869,8 +967,11 @@ def finalize(
current_step: DiffusionStep,
last_step_info: StepInfo,
) -> DiffusionStep:
return self.update(
return _finalize_with_stop_tokens(
self,
prediction,
current_step,
last_step_info,
self.stop_tokens,
self.pad_token,
)
Loading