
Commit f2408eb

Speed up complete
1 parent 8ff5b7a commit f2408eb

4 files changed

Lines changed: 31 additions & 28 deletions


Config.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ class Config(_ConfigPaths, _ConfigAgent):
     train_batch_size = 128
     training_buffer_len = 100_000
     min_n_points_to_finish = 15
-    n_simulations = 100
+    n_simulations = 1000
     n_games = None
     n_players = 2
     n_actions = 45
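
Note: the only configuration change raises n_simulations from 100 to 1000, i.e. ten times as many search simulations per move; the inference and bookkeeping speedups in the files below are presumably what make this larger budget affordable.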

agent/policy.py

Lines changed: 4 additions & 4 deletions
@@ -13,12 +13,12 @@ def policy(
     c: float,
     n_simulations: int,
 ):
-    N = defaultdict(list)
+    N = defaultdict(lambda: defaultdict(int))
     visited = set()
-    P = defaultdict(list)
-    Q = defaultdict(list)
+    P = defaultdict(dict)
+    Q = defaultdict(dict)
     initial_state = game.get_state()
-    all_moves = game.get_possible_actions()
+    all_moves = game.all_moves
     for _ in range(n_simulations):
         search(game.copy(), agent, c, N, visited, P, Q)
     pi = np.array([N[initial_state][a] for a in all_moves])
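
The containers now match how they are read: N becomes a state -> action -> visit-count map, so N[initial_state][a] reads as 0 for a move the search never tried (rather than misusing a list), and the visit vector pi can be built directly over all_moves. A minimal sketch of the pattern, with hypothetical state and action keys:

from collections import defaultdict

# state -> action -> visit count; missing entries read as 0
N = defaultdict(lambda: defaultdict(int))
N["s0"]["a1"] += 2  # pretend the search visited action "a1" twice
print([N["s0"][a] for a in ("a0", "a1", "a2")])  # [0, 2, 0]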

agent/search.py

Lines changed: 25 additions & 21 deletions
@@ -1,6 +1,7 @@
-from collections import defaultdict
 from math import sqrt
+from collections import defaultdict

+import torch
 from torch import nn, Tensor

 from src.Game import Game
@@ -10,42 +11,45 @@ def search(
     game: Game,
     agent: nn.Module,
     c: float,
-    N: defaultdict[list[int]],
+    N: defaultdict,
     visited: set,
-    P: defaultdict[list],
-    Q: defaultdict[list],
+    P: defaultdict,
+    Q: defaultdict,
 ):
     if game.is_terminal():
         return -game.get_results()[game.current_player.id]
     state = game.get_state()
     if state not in visited:
         visited.add(state)
-        move_scores, v = agent(Tensor([state]))
+        with torch.no_grad():
+            move_scores, v = agent(Tensor([state]))
         tuple(
-            P[state].__setitem__(move, move_scores[0, index])
+            P[state].__setitem__(move, move_scores[0, index].item())
             for index, move in enumerate(game.all_moves)
         )
-        return -v
+        return -v.item()
     q_state = Q[state]
     p_state = P[state]
     n_state = N[state]
     sqrt_value = sqrt(sum(n_state.values()))
-    def _get_action(game: Game):
-        return max(
-            game.get_possible_actions(),
-            key=lambda action: q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action]),
-        )
+
     # def _get_action(game: Game):
-    #     actions = sorted(
-    #         game.all_moves,
-    #         key=lambda action: q_state.get(action, 1)
-    #         + c * p_state[action] * sqrt_value / (1 + n_state[action]),
-    #         reverse=True,
+    #     return max(
+    #         game.get_possible_actions(),
+    #         key=lambda action: q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action]),
     #     )
-    #     for action in actions:
-    #         if action.is_valid(game):
-    #             return action
-    action = _get_action(game)
+    # def _get_action(game: Game):
+    #     best_action = None
+    #     best_value = -float('inf')
+    #     for action in game.all_moves:
+    #         value = q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action])
+    #         if value > best_value and action.is_valid(game):
+    #             best_value, best_action = value, action
+    #     return best_action
+    action = max(
+        game.get_possible_actions(),
+        key=lambda action: q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action]),
    )
     next_game_state = game.perform(action)
     v = search(next_game_state, agent, c, N, visited, P, Q)

agent/self_play.py

Lines changed: 1 addition & 2 deletions
@@ -32,7 +32,7 @@ def self_play(
 def _perform_game(
     game: Game, states: list, id_to_agent: dict[int, Agent]
 ) -> tuple[list[tuple[np.array, np.array, int]], Agent]:
-    for turn in tqdm(count()):
+    for _ in tqdm(count()):
         agent = id_to_agent[game.current_player.id]
         pi, action = policy(game, agent, Config.c, Config.n_simulations)
         states.append((game, pi / pi.sum(), 0))
@@ -47,7 +47,6 @@ def _perform_game(
                 int(result[state[0].current_player.id] == 1),
             )
             for state in states
-            if state[1] != game.null_move
         ),
         id_to_agent[
             next(player.id for player in game.players if result[player.id])
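
Two small cleanups: the unused loop variable turn becomes _, and the if state[1] != game.null_move filter is dropped from the generator that builds training examples, so every recorded state now yields one example.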
