diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py
index 746489d..74e675c 100644
--- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py
+++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py
@@ -321,6 +321,27 @@ def _sample_routing(
   return new_xt
 
 
+def _generate_candidates(
+    corruption_process: CategoricalProcess,
+    prediction: TargetInfo,
+    xt: DataArray,
+    time_bcast: TimeArray,
+    key: PRNGKey,
+    temperature: float,
+) -> tuple[DataArray, DataArray, Float['... M']]:
+  """Samples x0 and x_noise candidates; also returns the scaled logits."""
+  converted = corruption_process.convert_predictions(prediction, xt, time_bcast)
+  # Dividing by the temperature rescales the logits used to sample x0.
+  logits = converted['logits'] / temperature
+  # Separate subkeys so the x0 and x_noise draws do not share randomness.
+  key_x0, key_noise = jax.random.split(key)
+  # `[..., None]` appends a trailing axis to the sampled category indices.
+  x0 = jax.random.categorical(key=key_x0, logits=logits)[..., None]
+  x_noise = corruption_process.sample_from_invariant(key_noise, data_spec=xt)
+
+  return x0, x_noise, logits
+
+
 class RoutingStrategy(Protocol):
   """Protocol for transforming routing weights.
 
@@ -453,20 +474,15 @@ def update( next_time_bcast = utils.bcast_right(next_time, xt.ndim) key = next_step_info.rng - # Get model predictions - logits = self.corruption_process.convert_predictions( + # Get model predictions and candidates + _, candidate_key, plan_key, route_key = jax.random.split(key, 4) + x0, x_noise, logits = _generate_candidates( + self.corruption_process, prediction, xt, time_bcast, - )['logits'] - logits = logits / self.temperature - - _, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5) - - # Sample candidates - x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None] - x_noise = self.corruption_process.sample_from_invariant( - noise_key, data_spec=xt + candidate_key, + self.temperature, ) currently_masked = self.corruption_mask_fn(xt) # (bsz, seq_len, 1) @@ -675,20 +691,15 @@ def update( next_time_bcast = utils.bcast_right(next_time, xt.ndim) key = next_step_info.rng - # Get model predictions - logits = self.corruption_process.convert_predictions( + # Get model predictions and candidates + _, candidate_key, plan_key, route_key = jax.random.split(key, 4) + x0, x_noise, logits = _generate_candidates( + self.corruption_process, prediction, xt, time_bcast, - )['logits'] - logits = logits / self.temperature - - _, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5) - - # Sample candidates - x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None] - x_noise = self.corruption_process.sample_from_invariant( - noise_key, data_spec=xt + candidate_key, + self.temperature, ) # Schedule @@ -744,8 +755,6 @@ def update( step_info=next_step_info, aux={'logits': logits}, ) - # `logits` need to be passed in `aux` dictionary to a performance - # bug when using TPU. Needs to be investigated. 
@kt.typechecked def finalize( @@ -829,20 +838,15 @@ def update( next_time_bcast = utils.bcast_right(next_time, xt.ndim) key = next_step_info.rng - # Get model predictions - logits = self.corruption_process.convert_predictions( + # Get model predictions and candidates + _, candidate_key, plan_key, route_key = jax.random.split(key, 4) + x0, x_noise, logits = _generate_candidates( + self.corruption_process, prediction, xt, time_bcast, - )['logits'] - logits = logits / self.temperature - - _, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5) - - # Sample candidates - x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None] - x_noise = self.corruption_process.sample_from_invariant( - noise_key, data_spec=xt + candidate_key, + self.temperature, ) # Denoising rates