
Commit bcd46e1

Added pretraining with saved data
1 parent 04c4c39 commit bcd46e1

6 files changed: 26 additions & 15 deletions


Config.py

Lines changed: 13 additions & 6 deletions
@@ -4,10 +4,6 @@
 import numpy as np
 import torch
 
-random.seed(42)
-np.random.seed(42)
-torch.random.manual_seed(42)
-
 
 class _ConfigPaths:
     root = Path(__file__).parent
@@ -17,12 +13,17 @@ class _ConfigPaths:
     model_path.mkdir(exist_ok=True)
 
 
-class Config(_ConfigPaths):
+class _ConfigAgent:
     # hidden_sizes = (256, 128, 64, 32)
-    c = .1
     hidden_sizes = (256,)
     # hidden_sizes = tuple()
+    c = .1
     learning_rate = 1e-3
+    debug = False
+    pretrain = True
+
+
+class Config(_ConfigPaths, _ConfigAgent):
     max_results_held = 100
     minimal_relative_agent_improvement = 1.1
     min_games_to_replace_agents = 20
@@ -33,3 +34,9 @@ class Config(_ConfigPaths):
     n_games = None
     n_players = 2
     n_actions = 46
+
+
+if Config.debug:
+    random.seed(42)
+    np.random.seed(42)
+    torch.random.manual_seed(42)
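This diff splits the agent hyperparameters into a _ConfigAgent mixin and moves the global seeding behind the new debug flag, so ordinary self-play runs stay stochastic. A minimal sketch of how the composed Config resolves attributes through both bases (the data_dir stand-in is illustrative, not from the repo):

import random

import numpy as np
import torch


class _ConfigPaths:
    data_dir = "data"  # illustrative stand-in for the real path attributes


class _ConfigAgent:
    c = .1
    debug = False
    pretrain = True


class Config(_ConfigPaths, _ConfigAgent):
    n_players = 2


# Attribute lookup walks the MRO, so Config sees both mixins' fields.
assert Config.c == .1 and Config.data_dir == "data"

# Seeding is now opt-in: flip debug to True for reproducible runs.
if Config.debug:
    random.seed(42)
    np.random.seed(42)
    torch.random.manual_seed(42)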

agent/RLDataset.py

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@ def __len__(self):
         return len(self.examples)
 
     def __getitem__(self, index) -> tuple[np.array, ...]:
-        return np.array(self.examples[index][0]), self.examples[index][1], np.array([self.examples[index][2]])
+        return np.array(self.examples[index][0]), np.array(self.examples[index][1]), np.array([self.examples[index][2]])
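The extra np.array around the policy targets matters once the dataset goes through a DataLoader: the default collate function stacks per-sample numpy arrays into one batch tensor, whereas a plain Python list is transposed into a list of per-element tensors. A self-contained sketch, assuming each example is a (state, pi, result) tuple as in this repo:

import numpy as np
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        state, pi, result = self.examples[index]
        return np.array(state), np.array(pi), np.array([result])


examples = [([0, 1], [0.3, 0.7], 1), ([1, 0], [0.5, 0.5], -1)]
states, pis, results = next(iter(DataLoader(ToyDataset(examples), batch_size=2)))
print(pis.shape)  # torch.Size([2, 2]) -- one stacked tensor, not a list of tensors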

agent/policy.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 
 def policy(
     game: Game,
-    agent: nn.Module,
+    agents: dict[int, nn.Module],
     c: float,
     n_simulations: int,
 ):
@@ -20,6 +20,6 @@ def policy(
     initial_state = game.get_state()
     all_moves = game.get_possible_actions()
     for _ in range(n_simulations):
-        search(game.copy(), agent, c, N, visited, P, Q)
+        search(game.copy(), agents, c, N, visited, P, Q)
     pi = [N[initial_state][a] for a in all_moves]
     return pi, all_moves[np.argmax(pi)]
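After this change, callers hand policy the whole id-to-network mapping (as self_play now does) and search picks the right model per simulated move. The returned pi is the raw visit-count vector over the root's legal moves, and the chosen action is simply the most-visited one; a toy illustration of that final step, with made-up moves and counts:

import numpy as np

# Toy visit counts for three hypothetical legal moves.
all_moves = ["draw", "pass", "play"]
pi = [3, 10, 1]
assert all_moves[np.argmax(pi)] == "pass"  # the most-simulated action wins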

agent/search.py

Lines changed: 3 additions & 2 deletions
@@ -8,7 +8,7 @@
 
 def search(
     game: Game,
-    agent: nn.Module,
+    agents: dict[int, nn.Module],
     c: float,
     N: defaultdict,
     visited: set,
@@ -18,6 +18,7 @@ def search(
     if game.is_terminal():
         return game.get_results()[game.current_player.id]
     state = game.get_state()
+    agent = agents[game.current_player.id]
     if state not in visited:
         visited.add(state)
         move_scores, v = agent(Tensor([state]))
@@ -34,7 +35,7 @@ def search(
     )
 
     next_game_state = game.perform(action)
-    v = search(next_game_state, agent, c, N, visited, P, Q)
+    v = search(next_game_state, agents, c, N, visited, P, Q)
 
     Q[state][action] = (N[state][action] * Q[state].get(action, 1) + v) / (
         N[state][action] + 1
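The lookup agents[game.current_player.id] is the heart of this change: each simulated position is evaluated by the network belonging to the player to move. The Q update beneath it is an incremental running mean, which is also why the Q[state].get(action, 1) default never leaks into the value: on the first update the visit count is zero, so the default is multiplied away. A standalone check of the identity:

# (n * q + v) / (n + 1) folds a new value v into a mean q over n samples;
# applied repeatedly it reproduces the plain average of everything seen.
values = [0.5, -1.0, 1.0]
q, n = 1, 0  # q starts at the .get(action, 1) default
for v in values:
    q = (n * q + v) / (n + 1)
    n += 1
assert abs(q - sum(values) / len(values)) < 1e-12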

agent/self_play.py

Lines changed: 2 additions & 3 deletions
@@ -1,5 +1,5 @@
 from collections import deque
-from itertools import cycle, count
+from itertools import count
 
 import numpy as np
 from tqdm import tqdm
@@ -17,8 +17,7 @@ def self_play(agents: deque[Agent]) -> tuple[list[tuple[np.array, np.array, int]
     for agent in agents:
         agent.eval()
     for _ in tqdm(count()):
-        agent = id_to_agent[game.current_player.id]
-        pi, action = policy(game, agent, Config.c, Config.n_simulations)
+        pi, action = policy(game, id_to_agent, Config.c, Config.n_simulations)
         action_index = game.all_moves.index(action)
         onehot_encoded_action = np.zeros(Config.n_actions)
         onehot_encoded_action[action_index] = 1
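With the per-turn agent lookup moved into the search, the loop body only builds the training target: a one-hot vector over the fixed global move list, indexed by where the chosen action sits in game.all_moves. A minimal sketch with a made-up five-move game (the real list has Config.n_actions == 46 entries):

import numpy as np

all_moves = ["draw", "pass", "play_0", "play_1", "play_2"]
action = "play_1"
onehot_encoded_action = np.zeros(len(all_moves))
onehot_encoded_action[all_moves.index(action)] = 1
print(onehot_encoded_action)  # [0. 0. 0. 1. 0.]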

main.py

Lines changed: 5 additions & 1 deletion
@@ -2,6 +2,7 @@
 from collections import deque
 from copy import deepcopy
 from itertools import count
+from pathlib import Path
 
 import torch
 
@@ -14,10 +15,13 @@
 def main():
     training_buffer = deque(maxlen=Config.training_buffer_len)
     agents = deque((Agent(Config.n_players) for _ in range(Config.n_players)), maxlen=Config.n_players)
+    if Config.pretrain:
+        training_buffer += list(map(eval, map(Path.read_text, sorted(Config.data_path.iterdir(), key=lambda path: int(path.name), reverse=True)[:Config.training_buffer_len])))
+        train_agent(agents[-1], training_buffer)
     scores = deque(maxlen=Config.max_results_held)
     for _ in (count() if Config.n_games is None else range(Config.n_games)):
         buffer, winner = self_play(agents)
-        Config.data_path.joinpath(str(max((*tuple(map(int, map(str, Config.data_path.iterdir()))), -1)) + 1)).write_text(str((list(buffer[0][0]), list(buffer[0][1]), buffer[0][2])))
+        Config.data_path.joinpath(str(max((*tuple(int(path.name) for path in Config.data_path.iterdir()), -1)) + 1)).write_text(str((list(buffer[0][0]), list(buffer[0][1]), buffer[0][2])))
         scores.append(agents[-1] is winner)
         if len(scores) >= Config.min_games_to_replace_agents and sum(scores) > Config.minimal_relative_agent_improvement * len(scores) / len(agents):
             torch.save(agents[-1].state_dict(), Config.model_path.joinpath(str(max(map(int, (*re.findall(r'\d+', ''.join(map(str, Config.model_path.iterdir()))), -1))) + 1) + ".pth"))
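The new pretraining block reads back the most recent saved examples by numeric filename and round-trips them through str() and eval(). Since each file holds a single (list, list, int) literal, ast.literal_eval parses the same format while refusing arbitrary expressions; a sketch of that safer round-trip, using a hypothetical file path rather than Config.data_path:

from ast import literal_eval
from pathlib import Path

path = Path("data/0")  # hypothetical saved-example file
path.parent.mkdir(exist_ok=True)
path.write_text(str(([0.0, 1.0], [0.25, 0.75], 1)))

state, pi, result = literal_eval(path.read_text())
print(state, pi, result)  # [0.0, 1.0] [0.25, 0.75] 1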
