Skip to content

Commit 632fc3c

Browse files
committed
It is alive
1 parent daf75ee commit 632fc3c

3 files changed

Lines changed: 17 additions & 6 deletions

File tree

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1 +1,3 @@
 /sandbox.py
+/data/
+/models/

Config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ class Config(_ConfigPaths, _ConfigAgent):
     train_batch_size = 64
     training_buffer_len = 1000
     min_n_points_to_finish = 15
-    n_simulations = 250
+    n_simulations = 100
     n_games = None
     n_players = 2
     n_actions = 46

main.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
 import re
 from collections import deque
 from copy import deepcopy
-from itertools import count
+from itertools import count, islice
 from pathlib import Path
 
 import torch
@@ -16,19 +16,28 @@ def main():
     training_buffer = deque(maxlen=Config.training_buffer_len)
     agents = deque((Agent(Config.n_players) for _ in range(Config.n_players)), maxlen=Config.n_players)
     if Config.pretrain:
+        for agent, checkpoint_index in zip(islice(reversed(agents), 1, None), sorted((int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), reverse=True)):
+            agent.load_state_dict(torch.load(Config.model_path.joinpath(f'{checkpoint_index}.pth')))
+        agents[-1].load_state_dict(torch.load(Config.model_path.joinpath(f"{max(int(path.name.split('.')[0]) for path in Config.model_path.iterdir())}.pth")))
         training_buffer += list(map(eval, map(Path.read_text, sorted(Config.data_path.iterdir(), key=lambda path: int(path.name), reverse=True)[:Config.training_buffer_len])))
         train_agent(agents[-1], training_buffer)
     scores = deque(maxlen=Config.max_results_held)
     for _ in (count() if Config.n_games is None else range(Config.n_games)):
         buffer, winner = self_play(agents)
-        Config.data_path.joinpath(str(max((*tuple(int(path.name) for path in Config.data_path.iterdir()), -1)) + 1)).write_text(str((list(buffer[0][0]), list(buffer[0][1]), buffer[0][2])))
+        start_index = max((*tuple(int(path.name) for path in Config.data_path.iterdir()), -1)) + 1
+        for start_index, sample in enumerate(buffer, start_index + 1):
+            Config.data_path.joinpath(str(start_index)).write_text(str((list(sample[0]), list(sample[1]), sample[2])))
         scores.append(agents[-1] is winner)
-        if len(scores) >= Config.min_games_to_replace_agents and sum(scores) > Config.minimal_relative_agent_improvement * len(scores) / len(agents):
+        if (len(scores) < Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * Config.min_games_to_replace_agents / len(agents)) or (len(scores) >= Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * len(scores) / len(agents)):
             torch.save(agents[-1].state_dict(), Config.model_path.joinpath(str(max(map(int, (*re.findall(r'\d+', ''.join(map(str, Config.model_path.iterdir()))), -1))) + 1) + ".pth"))
-            agents.append(Agent(Config.n_players).load_state_dict(deepcopy(agents[-1].state_dict())))
+            agents.append(Agent(Config.n_players))
+            agents[-1].load_state_dict(deepcopy(agents[-1].state_dict()))
             agents[-1].training = True
             scores = deque(maxlen=Config.max_results_held)
-        print(sum(scores) / len(scores), len(scores))
+        elif len(scores) >= Config.min_games_to_replace_agents:
+            print(f'{len(scores)} {sum(scores) / len(scores):.2f}')
+        else:
+            print(f'{len(scores)} {sum(scores)}/{len(scores)}')
         training_buffer += buffer
         train_agent(agents[-1], training_buffer)
 
0 commit comments

Comments
 (0)