@@ -321,6 +321,27 @@ def _sample_routing(
321321 return new_xt
322322
323323
324+ def _generate_candidates (
325+ corruption_process : CategoricalProcess ,  # supplies convert_predictions() and sample_from_invariant()
326+ prediction : TargetInfo ,  # model output, forwarded unchanged to convert_predictions()
327+ xt : DataArray ,  # current state; also passed as data_spec for the noise draw
328+ time_bcast : TimeArray ,  # presumably time already broadcast against xt -- TODO confirm with callers
329+ key : PRNGKey ,  # split below into one subkey per random draw
330+ temperature : float ,  # divisor applied to the logits; NOTE(review): no guard against 0 here
331+ ) -> tuple [DataArray , DataArray , Float ['... M' ]]:  # (x0, x_noise, temperature-scaled logits)
332+ """Draw candidate x0 and x_noise samples and return them with the temperature-scaled logits."""
333+ logits = corruption_process .convert_predictions (prediction , xt , time_bcast )[
334+ 'logits'
335+ ]
336+ logits = logits / temperature  # the scaled logits are both sampled from and returned
337+
338+ x0_key , noise_key = jax .random .split (key )  # independent subkeys for the two draws
339+ x0 = jax .random .categorical (key = x0_key , logits = logits )[..., None ]  # `[..., None]` appends a trailing size-1 axis
340+ x_noise = corruption_process .sample_from_invariant (noise_key , data_spec = xt )  # presumably a draw from the process's invariant distribution -- verify
341+
342+ return x0 , x_noise , logits
343+
344+
324345class RoutingStrategy (Protocol ):
325346 """Protocol for transforming routing weights.
326347
@@ -453,20 +474,15 @@ def update(
453474 next_time_bcast = utils .bcast_right (next_time , xt .ndim )
454475 key = next_step_info .rng
455476
456- # Get model predictions
457- logits = self .corruption_process .convert_predictions (
477+ # Get model predictions and candidates
478+ _ , candidate_key , plan_key , route_key = jax .random .split (key , 4 )
479+ x0 , x_noise , logits = _generate_candidates (
480+ self .corruption_process ,
458481 prediction ,
459482 xt ,
460483 time_bcast ,
461- )['logits' ]
462- logits = logits / self .temperature
463-
464- _ , x0_key , noise_key , plan_key , route_key = jax .random .split (key , 5 )
465-
466- # Sample candidates
467- x0 = jax .random .categorical (key = x0_key , logits = logits )[..., None ]
468- x_noise = self .corruption_process .sample_from_invariant (
469- noise_key , data_spec = xt
484+ candidate_key ,
485+ self .temperature ,
470486 )
471487
472488 currently_masked = self .corruption_mask_fn (xt ) # (bsz, seq_len, 1)
@@ -675,20 +691,15 @@ def update(
675691 next_time_bcast = utils .bcast_right (next_time , xt .ndim )
676692 key = next_step_info .rng
677693
678- # Get model predictions
679- logits = self .corruption_process .convert_predictions (
694+ # Get model predictions and candidates
695+ _ , candidate_key , plan_key , route_key = jax .random .split (key , 4 )
696+ x0 , x_noise , logits = _generate_candidates (
697+ self .corruption_process ,
680698 prediction ,
681699 xt ,
682700 time_bcast ,
683- )['logits' ]
684- logits = logits / self .temperature
685-
686- _ , x0_key , noise_key , plan_key , route_key = jax .random .split (key , 5 )
687-
688- # Sample candidates
689- x0 = jax .random .categorical (key = x0_key , logits = logits )[..., None ]
690- x_noise = self .corruption_process .sample_from_invariant (
691- noise_key , data_spec = xt
701+ candidate_key ,
702+ self .temperature ,
692703 )
693704
694705 # Schedule
@@ -744,8 +755,6 @@ def update(
744755 step_info = next_step_info ,
745756 aux = {'logits' : logits },
746757 )
747- # `logits` need to be passed in `aux` dictionary to a performance
748- # bug when using TPU. Needs to be investigated.
749758
750759 @kt .typechecked
751760 def finalize (
@@ -829,20 +838,15 @@ def update(
829838 next_time_bcast = utils .bcast_right (next_time , xt .ndim )
830839 key = next_step_info .rng
831840
832- # Get model predictions
833- logits = self .corruption_process .convert_predictions (
841+ # Get model predictions and candidates
842+ _ , candidate_key , plan_key , route_key = jax .random .split (key , 4 )
843+ x0 , x_noise , logits = _generate_candidates (
844+ self .corruption_process ,
834845 prediction ,
835846 xt ,
836847 time_bcast ,
837- )['logits' ]
838- logits = logits / self .temperature
839-
840- _ , x0_key , noise_key , plan_key , route_key = jax .random .split (key , 5 )
841-
842- # Sample candidates
843- x0 = jax .random .categorical (key = x0_key , logits = logits )[..., None ]
844- x_noise = self .corruption_process .sample_from_invariant (
845- noise_key , data_spec = xt
848+ candidate_key ,
849+ self .temperature ,
846850 )
847851
848852 # Denoising rates
0 commit comments