@@ -963,5 +963,97 @@ def __call__(self, routing_weights, logits, x0, xt, time, next_time, key):
963963 chex .assert_trees_all_equal (out , routing_weights )
964964
965965
class GreedyPlannerTest(absltest.TestCase):
  """Exercises `GreedyPlanner` on a single sequence of four positions.

  All cases share the same logits, `x0`, `xt`, `time` and PRNG key; they differ
  only in the incoming routing weights and in `next_time`.
  """

  def _plan(self, routing_weights, next_time):
    """Calls a fresh `GreedyPlanner` with the inputs shared by every case.

    Args:
      routing_weights: `RoutingWeights` instance handed to the planner.
      next_time: Python float, wrapped into a length-1 array for the call.

    Returns:
      Whatever the planner returns for these inputs.
    """
    # The first logit dominates less and less along the sequence: 10, 5, 2, 1.
    logits = jnp.array(
        [[[10.0, 0.0], [5.0, 0.0], [2.0, 0.0], [1.0, 0.0]]]
    )
    token_shape = (1, 4, 1)
    return discrete_step_sampler.GreedyPlanner()(
        routing_weights,
        logits,
        jnp.zeros(token_shape, dtype=jnp.int32),  # x0
        jnp.ones(token_shape, dtype=jnp.int32),  # xt
        jnp.array([1.0]),  # time
        jnp.array([next_time]),
        jax.random.PRNGKey(0),
    )

  def _assert_weights_close(self, actual, expected):
    """Compares the stay, noise and clean fields one by one, in that order."""
    for field in ('stay', 'noise', 'clean'):
      chex.assert_trees_all_close(
          getattr(actual, field), getattr(expected, field)
      )

  def test_greedy_planner_budget(self):
    """Uniform weights: the two leading positions are expected to go clean."""
    # One batch entry, four positions, every position has clean weight > 0.
    incoming = discrete_step_sampler.RoutingWeights(
        stay=jnp.full((1, 4, 1), 0.3),
        noise=jnp.full((1, 4, 1), 0.2),
        clean=jnp.full((1, 4, 1), 0.5),
    )

    # Original test note: frac = 0.5 -> budget = 4 eligible * 0.5 = 2.
    result = self._plan(incoming, next_time=0.5)

    # Selected positions (0 and 1): stay=0, noise=0, clean=1.
    # Remaining positions: stay/noise untouched, clean set to 0.
    wanted = discrete_step_sampler.RoutingWeights(
        stay=jnp.array([[[0.0], [0.0], [0.3], [0.3]]]),
        noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
        clean=jnp.array([[[1.0], [1.0], [0.0], [0.0]]]),
    )
    self._assert_weights_close(result, wanted)

  def test_greedy_planner_eligibility(self):
    """A position with zero clean weight must not be picked."""
    # Position 0 carries the largest logit but has clean=0 and stay=1.0.
    incoming = discrete_step_sampler.RoutingWeights(
        stay=jnp.array([[[1.0], [0.3], [0.3], [0.3]]]),
        noise=jnp.array([[[0.0], [0.2], [0.2], [0.2]]]),
        clean=jnp.array([[[0.0], [0.5], [0.5], [0.5]]]),
    )

    # Original test note: 3 eligible * 0.5 = 1 after truncation to int.
    result = self._plan(incoming, next_time=0.5)

    # Position 1 is the best remaining candidate and is expected to go clean;
    # positions 0, 2 and 3 keep their stay/noise and get clean=0.
    wanted = discrete_step_sampler.RoutingWeights(
        stay=jnp.array([[[1.0], [0.0], [0.3], [0.3]]]),
        noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
        clean=jnp.array([[[0.0], [1.0], [0.0], [0.0]]]),
    )
    self._assert_weights_close(result, wanted)

  def test_greedy_planner_k_zero(self):
    """With a budget of k=0 nothing is selected; stay/noise pass through.

    The case uses a clean weight of 0.1 at every position and a `next_time`
    equal to `time`.
    """
    incoming = discrete_step_sampler.RoutingWeights(
        stay=jnp.full((1, 4, 1), 0.7),
        noise=jnp.full((1, 4, 1), 0.2),
        clean=jnp.full((1, 4, 1), 0.1),
    )

    result = self._plan(incoming, next_time=1.0)

    # No position selected: stay/noise are unchanged and clean is all zeros.
    wanted = discrete_step_sampler.RoutingWeights(
        stay=jnp.full((1, 4, 1), 0.7),
        noise=jnp.full((1, 4, 1), 0.2),
        clean=jnp.zeros((1, 4, 1)),
    )
    self._assert_weights_close(result, wanted)
1056+
1057+
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  absltest.main()
0 commit comments