Skip to content

Commit ae44b17

Browse files
vdebortoHackable Diffusion Authors
authored andcommitted
Factor out candidate generation and post-processing in discrete samplers to reduce boilerplate. This makes the update methods focus purely on computing the routing probabilities.
PiperOrigin-RevId: 911403532
1 parent bcf84ea commit ae44b17

1 file changed

Lines changed: 39 additions & 35 deletions

File tree

hackable_diffusion/lib/sampling/discrete_step_sampler.py

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,27 @@ def _sample_routing(
321321
return new_xt
322322

323323

324+
def _generate_candidates(
325+
corruption_process: CategoricalProcess,
326+
prediction: TargetInfo,
327+
xt: DataArray,
328+
time_bcast: TimeArray,
329+
key: PRNGKey,
330+
temperature: float,
331+
) -> tuple[DataArray, DataArray, Float['... M']]:
332+
"""Generate candidate x0, x_noise samples and logits."""
333+
logits = corruption_process.convert_predictions(prediction, xt, time_bcast)[
334+
'logits'
335+
]
336+
logits = logits / temperature
337+
338+
x0_key, noise_key = jax.random.split(key)
339+
x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None]
340+
x_noise = corruption_process.sample_from_invariant(noise_key, data_spec=xt)
341+
342+
return x0, x_noise, logits
343+
344+
324345
class RoutingStrategy(Protocol):
325346
"""Protocol for transforming routing weights.
326347
@@ -453,20 +474,15 @@ def update(
453474
next_time_bcast = utils.bcast_right(next_time, xt.ndim)
454475
key = next_step_info.rng
455476

456-
# Get model predictions
457-
logits = self.corruption_process.convert_predictions(
477+
# Get model predictions and candidates
478+
_, candidate_key, plan_key, route_key = jax.random.split(key, 4)
479+
x0, x_noise, logits = _generate_candidates(
480+
self.corruption_process,
458481
prediction,
459482
xt,
460483
time_bcast,
461-
)['logits']
462-
logits = logits / self.temperature
463-
464-
_, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5)
465-
466-
# Sample candidates
467-
x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None]
468-
x_noise = self.corruption_process.sample_from_invariant(
469-
noise_key, data_spec=xt
484+
candidate_key,
485+
self.temperature,
470486
)
471487

472488
currently_masked = self.corruption_mask_fn(xt) # (bsz, seq_len, 1)
@@ -675,20 +691,15 @@ def update(
675691
next_time_bcast = utils.bcast_right(next_time, xt.ndim)
676692
key = next_step_info.rng
677693

678-
# Get model predictions
679-
logits = self.corruption_process.convert_predictions(
694+
# Get model predictions and candidates
695+
_, candidate_key, plan_key, route_key = jax.random.split(key, 4)
696+
x0, x_noise, logits = _generate_candidates(
697+
self.corruption_process,
680698
prediction,
681699
xt,
682700
time_bcast,
683-
)['logits']
684-
logits = logits / self.temperature
685-
686-
_, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5)
687-
688-
# Sample candidates
689-
x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None]
690-
x_noise = self.corruption_process.sample_from_invariant(
691-
noise_key, data_spec=xt
701+
candidate_key,
702+
self.temperature,
692703
)
693704

694705
# Schedule
@@ -744,8 +755,6 @@ def update(
744755
step_info=next_step_info,
745756
aux={'logits': logits},
746757
)
747-
# `logits` need to be passed in `aux` dictionary to a performance
748-
# bug when using TPU. Needs to be investigated.
749758

750759
@kt.typechecked
751760
def finalize(
@@ -829,20 +838,15 @@ def update(
829838
next_time_bcast = utils.bcast_right(next_time, xt.ndim)
830839
key = next_step_info.rng
831840

832-
# Get model predictions
833-
logits = self.corruption_process.convert_predictions(
841+
# Get model predictions and candidates
842+
_, candidate_key, plan_key, route_key = jax.random.split(key, 4)
843+
x0, x_noise, logits = _generate_candidates(
844+
self.corruption_process,
834845
prediction,
835846
xt,
836847
time_bcast,
837-
)['logits']
838-
logits = logits / self.temperature
839-
840-
_, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5)
841-
842-
# Sample candidates
843-
x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None]
844-
x_noise = self.corruption_process.sample_from_invariant(
845-
noise_key, data_spec=xt
848+
candidate_key,
849+
self.temperature,
846850
)
847851

848852
# Denoising rates

0 commit comments

Comments
 (0)