1- from collections import defaultdict
21from math import sqrt
2+ from collections import defaultdict
33
4+ import torch
45from torch import nn , Tensor
56
67from src .Game import Game
@@ -10,42 +11,45 @@ def search(
1011 game : Game ,
1112 agent : nn .Module ,
1213 c : float ,
13- N : defaultdict [ list [ int ]] ,
14+ N : defaultdict ,
1415 visited : set ,
15- P : defaultdict [ list ] ,
16- Q : defaultdict [ list ] ,
16+ P : defaultdict ,
17+ Q : defaultdict ,
1718):
1819 if game .is_terminal ():
1920 return - game .get_results ()[game .current_player .id ]
2021 state = game .get_state ()
2122 if state not in visited :
2223 visited .add (state )
23- move_scores , v = agent (Tensor ([state ]))
24+ with torch .no_grad ():
25+ move_scores , v = agent (Tensor ([state ]))
2426 tuple (
25- P [state ].__setitem__ (move , move_scores [0 , index ])
27+ P [state ].__setitem__ (move , move_scores [0 , index ]. item () )
2628 for index , move in enumerate (game .all_moves )
2729 )
28- return - v
30+ return - v . item ()
2931 q_state = Q [state ]
3032 p_state = P [state ]
3133 n_state = N [state ]
3234 sqrt_value = sqrt (sum (n_state .values ()))
33- def _get_action (game : Game ):
34- return max (
35- game .get_possible_actions (),
36- key = lambda action : q_state .get (action , 1 ) + c * p_state [action ] * sqrt_value / (1 + n_state [action ]),
37- )
35+
3836 # def _get_action(game: Game):
39- # actions = sorted(
40- # game.all_moves,
41- # key=lambda action: q_state.get(action, 1)
42- # + c * p_state[action] * sqrt_value / (1 + n_state[action]),
43- # reverse=True,
37+ # return max(
38+ # game.get_possible_actions(),
39+ # key=lambda action: q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action]),
4440 # )
45- # for action in actions:
46- # if action.is_valid(game):
47- # return action
48- action = _get_action (game )
41+ # def _get_action(game: Game):
42+ # best_action = None
43+ # best_value = -float('inf')
44+ # for action in game.all_moves:
45+ # value = q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action])
46+ # if value > best_value and action.is_valid(game):
47+ # best_value, best_action = value, action
48+ # return best_action
49+ action = max (
50+ game .get_possible_actions (),
51+ key = lambda action : q_state .get (action , 1 ) + c * p_state [action ] * sqrt_value / (1 + n_state [action ]),
52+ )
4953 next_game_state = game .perform (action )
5054 v = search (next_game_state , agent , c , N , visited , P , Q )
5155
0 commit comments