@@ -963,5 +963,145 @@ def __call__(self, routing_weights, logits, x0, xt, time, next_time, key):
963963 chex .assert_trees_all_equal (out , routing_weights )
964964
965965
class GreedyPlannerTest(absltest.TestCase):
  """Tests for `discrete_step_sampler.GreedyPlanner`.

  Contract exercised by these tests: positions chosen by the planner are
  forced to CLEAN (stay=0, noise=0, clean=1); all other positions keep their
  original stay/noise weights and have their clean weight zeroed.
  """

  def _two_class_logits(self, first_class_logits):
    """Pairs every first-class logit with a zero second-class logit.

    Args:
      first_class_logits: Array-like of per-position logits for class 0.

    Returns:
      Array with a trailing vocab axis of size 2: `[..., (logit, 0.0)]`.
    """
    first = jnp.asarray(first_class_logits)
    return jnp.stack([first, jnp.zeros_like(first)], axis=-1)

  def _run_planner(self, routing_weights, logits, *, next_time):
    """Runs a fresh GreedyPlanner for one step that starts at time 1.0.

    `x0` is all zeros and `xt` is all ones (both int32, shaped like
    `routing_weights.stay`); the PRNG key is fixed at `PRNGKey(0)`.

    Args:
      routing_weights: Input `RoutingWeights`.
      logits: Per-position logits with a trailing vocab axis.
      next_time: Python float, the time the step ends at.

    Returns:
      Whatever the planner returns (the re-planned routing weights).
    """
    shape = routing_weights.stay.shape
    planner = discrete_step_sampler.GreedyPlanner()
    return planner(
        routing_weights,
        logits,
        jnp.zeros(shape, dtype=jnp.int32),  # x0
        jnp.ones(shape, dtype=jnp.int32),  # xt
        jnp.array([1.0]),  # time
        jnp.array([next_time]),
        jax.random.PRNGKey(0),
    )

  def _assert_routing_close(self, actual, expected):
    """Asserts stay, noise and clean weights match field by field."""
    chex.assert_trees_all_close(actual.stay, expected.stay)
    chex.assert_trees_all_close(actual.noise, expected.noise)
    chex.assert_trees_all_close(actual.clean, expected.clean)

  def test_greedy_planner_budget(self):
    """Two of four eligible positions are forced clean, best-ranked first."""
    # 1 batch, 4 seq len. Realistic routing: all eligible with stay/noise > 0.
    routing_weights = discrete_step_sampler.RoutingWeights(
        stay=jnp.full((1, 4, 1), 0.3),
        noise=jnp.full((1, 4, 1), 0.2),
        clean=jnp.full((1, 4, 1), 0.5),
    )
    # High confidence first.
    logits = self._two_class_logits([[10.0, 5.0, 2.0, 1.0]])

    # frac = 0.5 -> budget = 4 eligible * 0.5 = 2.
    out_probs = self._run_planner(routing_weights, logits, next_time=0.5)

    # Top 2 positions -> force CLEAN (stay=0, noise=0, clean=1).
    # Non-selected -> keep original stay/noise, zero out clean.
    expected = discrete_step_sampler.RoutingWeights(
        stay=jnp.array([[[0.0], [0.0], [0.3], [0.3]]]),
        noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
        clean=jnp.array([[[1.0], [1.0], [0.0], [0.0]]]),
    )
    self._assert_routing_close(out_probs, expected)

  def test_greedy_planner_eligibility(self):
    """A position with zero clean weight is never selected."""
    # Position 0 is NOT eligible (p_clean = 0), has original stay=1.0.
    routing_weights = discrete_step_sampler.RoutingWeights(
        stay=jnp.array([[[1.0], [0.3], [0.3], [0.3]]]),
        noise=jnp.array([[[0.0], [0.2], [0.2], [0.2]]]),
        clean=jnp.array([[[0.0], [0.5], [0.5], [0.5]]]),
    )
    # Pos 0 has the highest logit but is ineligible.
    logits = self._two_class_logits([[10.0, 5.0, 2.0, 1.0]])

    # frac = 0.5 -> budget = 3 eligible * 0.5 = 1 (truncated to int).
    out_probs = self._run_planner(routing_weights, logits, next_time=0.5)

    # Top 1 eligible position by confidence: Pos 1 -> force CLEAN.
    # Non-selected (Pos 0, 2, 3) -> keep original stay/noise, zero clean.
    expected = discrete_step_sampler.RoutingWeights(
        stay=jnp.array([[[1.0], [0.0], [0.3], [0.3]]]),
        noise=jnp.array([[[0.0], [0.0], [0.2], [0.2]]]),
        clean=jnp.array([[[0.0], [1.0], [0.0], [0.0]]]),
    )
    self._assert_routing_close(out_probs, expected)

  def test_greedy_planner_k_zero(self):
    """With budget k=0 every position keeps its original stay/noise."""
    # NOTE(review): the original docstring attributed k=0 to the small clean
    # weight (0.1), but this setup also has next_time == time. Confirm which
    # of the two actually drives the budget to zero.
    routing_weights = discrete_step_sampler.RoutingWeights(
        stay=jnp.full((1, 4, 1), 0.7),
        noise=jnp.full((1, 4, 1), 0.2),
        clean=jnp.full((1, 4, 1), 0.1),
    )
    logits = self._two_class_logits([[10.0, 5.0, 2.0, 1.0]])

    out_probs = self._run_planner(routing_weights, logits, next_time=1.0)

    # k=0: no positions selected. All keep original stay/noise, clean zeroed.
    expected = discrete_step_sampler.RoutingWeights(
        stay=jnp.full((1, 4, 1), 0.7),
        noise=jnp.full((1, 4, 1), 0.2),
        clean=jnp.zeros((1, 4, 1)),
    )
    self._assert_routing_close(out_probs, expected)

  def test_greedy_planner_2d_spatial(self):
    """GreedyPlanner works with 2D spatial data (e.g. adjacency matrices)."""
    # Shape (1, 3, 3, 1): batch=1, spatial=(3,3), vocab_trailing=1.
    spatial_shape = (1, 3, 3, 1)
    routing_weights = discrete_step_sampler.RoutingWeights(
        stay=jnp.full(spatial_shape, 0.3),
        noise=jnp.full(spatial_shape, 0.2),
        clean=jnp.full(spatial_shape, 0.5),
    )
    # 9 positions total, 2 vocab classes. First-class logits run 10, 9, ..., 2
    # in row-major order, so (0,0) has the highest confidence, then (0,1), etc.
    logits = self._two_class_logits(
        jnp.arange(10.0, 1.0, -1.0).reshape(1, 3, 3)
    )

    # frac = 0.5 -> budget = 9 * 0.5 = 4.
    out = self._run_planner(routing_weights, logits, next_time=0.5)

    # Output must have the same spatial shape.
    self.assertEqual(out.stay.shape, spatial_shape)
    self.assertEqual(out.noise.shape, spatial_shape)
    self.assertEqual(out.clean.shape, spatial_shape)

    # Budget k = 4: exactly four positions are forced CLEAN.
    self.assertEqual(int(jnp.sum(out.clean[0, :, :, 0] > 0)), 4)

    # The four selected positions must be the first four in row-major order,
    # i.e. (0,0), (0,1), (0,2), (1,0). Checking the full layout (not only the
    # count) catches a result that was reshaped or transposed incorrectly.
    selected = jnp.arange(9).reshape(spatial_shape) < 4
    expected = discrete_step_sampler.RoutingWeights(
        stay=jnp.where(selected, 0.0, 0.3),
        noise=jnp.where(selected, 0.0, 0.2),
        clean=jnp.where(selected, 1.0, 0.0),
    )
    self._assert_routing_close(out, expected)
1104+
1105+
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  absltest.main()
0 commit comments