@@ -10,10 +10,10 @@ def search(
1010 game : Game ,
1111 agent : nn .Module ,
1212 c : float ,
13- N : defaultdict ,
13+ N : defaultdict [ list [ int ]] ,
1414 visited : set ,
15- P : defaultdict ,
16- Q : defaultdict ,
15+ P : defaultdict [ list ] ,
16+ Q : defaultdict [ list ] ,
1717):
1818 if game .is_terminal ():
1919 return - game .get_results ()[game .current_player .id ]
@@ -26,13 +26,26 @@ def search(
2626 for index , move in enumerate (game .all_moves )
2727 )
2828 return - v
29-
30- action = max (
31- game .get_possible_actions (),
32- key = lambda action : Q [state ].get (action , 1 )
33- + c * P [state ][action ] * sqrt (sum (N [state ].values ())) / (1 + N [state ][action ]),
34- )
35-
29+ q_state = Q [state ]
30+ p_state = P [state ]
31+ n_state = N [state ]
32+ sqrt_value = sqrt (sum (n_state .values ()))
33+ def _get_action (game : Game ):
34+ return max (
35+ game .get_possible_actions (),
36+ key = lambda action : q_state .get (action , 1 ) + c * p_state [action ] * sqrt_value / (1 + n_state [action ]),
37+ )
38+ # def _get_action(game: Game):
39+ # actions = sorted(
40+ # game.all_moves,
41+ # key=lambda action: q_state.get(action, 1)
42+ # + c * p_state[action] * sqrt_value / (1 + n_state[action]),
43+ # reverse=True,
44+ # )
45+ # for action in actions:
46+ # if action.is_valid(game):
47+ # return action
48+ action = _get_action (game )
3649 next_game_state = game .perform (action )
3750 v = search (next_game_state , agent , c , N , visited , P , Q )
3851
0 commit comments