+import random
 import re
 from collections import deque
 from copy import deepcopy
-from itertools import count, islice
-from pathlib import Path
+from itertools import count
 
 import torch
 
 from Config import Config
 from agent.Agent import Agent
+from agent.pretrain import pretrain
+from agent.save import save_temp_buffer
 from agent.self_play import self_play
 from agent.train_agent import train_agent
 
 
 def main():
     training_buffer = deque(maxlen=Config.training_buffer_len)
-    agents = deque((Agent(Config.n_players) for _ in range(Config.n_players)), maxlen=Config.n_players)
+    agents = deque(
+        (Agent(Config.n_players) for _ in range(Config.n_players)),
+        maxlen=Config.n_players,
+    )
     if Config.pretrain:
-        for agent, checkpoint_index in zip(islice(reversed(agents), 1, None), sorted((int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), reverse=True)):
-            agent.load_state_dict(torch.load(Config.model_path.joinpath(f'{checkpoint_index}.pth')))
-        newest = Config.model_path.joinpath(
-            f"{max((*tuple(int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), 0))}.pth")
-        if newest.exists():
-            agents[-1].load_state_dict(torch.load(newest))
-        training_buffer += list(map(eval, map(Path.read_text, sorted(Config.training_data_path.iterdir(), key=lambda path: int(path.name), reverse=True)[:Config.training_buffer_len])))
-        train_agent(agents[-1], training_buffer)
+        pretrain(agents)
     scores = deque(maxlen=Config.max_results_held)
-    for _ in (count() if Config.n_games is None else range(Config.n_games)):
+    for _ in count() if Config.n_games is None else range(Config.n_games):
         buffer, winners = self_play(agents)
-        start_index = max((*tuple(int(path.name) for path in Config.training_data_path.iterdir()), -1)) + 1
-        for start_index, sample in enumerate(buffer, start_index + 1):
-            Config.training_data_path.joinpath(str(start_index)).write_text(str((list(sample[0]), list(sample[1]), sample[2])))
+        # A fraction eval_rate of games is held out for evaluation only.
+        to_train = random.random() > Config.eval_rate
+        save_temp_buffer(buffer, to_train)
         for winner in winners:
             scores.append(agents[-1] is winner)
-        if (len(scores) < Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * Config.min_games_to_replace_agents / len(agents)) or (len(scores) >= Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * len(scores) / len(agents)):
-            torch.save(agents[-1].state_dict(), Config.model_path.joinpath(str(max(map(int, (*re.findall(r'\d+', ''.join(map(str, Config.model_path.iterdir()))), -1))) + 1) + ".pth"))
+        # Promote the newest agent once it wins clearly more often than the
+        # 1 / len(agents) chance baseline; with fewer than
+        # min_games_to_replace_agents results, require the win count a full
+        # window would need.
+        if (
+            len(scores) < Config.min_games_to_replace_agents
+            and sum(scores)
+            >= Config.minimal_relative_agent_improvement
+            * Config.min_games_to_replace_agents
+            / len(agents)
+        ) or (
+            len(scores) >= Config.min_games_to_replace_agents
+            and sum(scores)
+            >= Config.minimal_relative_agent_improvement * len(scores) / len(agents)
+        ):
+            # Checkpoint under the next unused integer index in model_path.
+            next_index = (
+                max(map(int, (*re.findall(r"\d+", "".join(map(str, Config.model_path.iterdir()))), -1)))
+                + 1
+            )
+            torch.save(agents[-1].state_dict(), Config.model_path.joinpath(f"{next_index}.pth"))
             agents.append(Agent(Config.n_players))
-            agents[-1].load_state_dict(deepcopy(agents[-1].state_dict()))
+            agents[-1].load_state_dict(deepcopy(agents[-2].state_dict()))
             agents[-1].training = True
             scores = deque(maxlen=Config.max_results_held)
         elif len(scores) >= Config.min_games_to_replace_agents:
-            print(f'{len(scores)} {sum(scores) / len(scores):.2f}')
+            print(f"{len(scores)} {sum(scores) / len(scores):.2f}")
         else:
-            print(f'{len(scores)} {sum(scores)}/{len(scores)}')
-        training_buffer += buffer
-        train_agent(agents[-1], training_buffer)
+            print(f"{len(scores)} {sum(scores)}/{len(scores)}")
+        if to_train:
+            training_buffer += buffer
+            train_agent(agents[-1], training_buffer)
 
 
 if __name__ == "__main__":
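
The commit moves the old inline pretraining block into agent/pretrain.py. Below is a minimal sketch of what pretrain might contain, reconstructed from the removed lines; the function name and signature come from the new import, but the helper's actual body, and whether it keeps its own training buffer or receives main's, are not visible in this diff.

# agent/pretrain.py -- a hypothetical reconstruction from the removed inline
# block; the real helper may differ.
from collections import deque
from itertools import islice

import torch

from Config import Config
from agent.train_agent import train_agent


def pretrain(agents):
    # Highest-numbered checkpoints first.
    checkpoint_indices = sorted(
        (int(path.name.split(".")[0]) for path in Config.model_path.iterdir()),
        reverse=True,
    )
    # Load the saved checkpoints into every agent except the newest.
    for agent, checkpoint_index in zip(islice(reversed(agents), 1, None), checkpoint_indices):
        agent.load_state_dict(torch.load(Config.model_path.joinpath(f"{checkpoint_index}.pth")))
    # The newest agent resumes from the highest-numbered checkpoint, if any.
    newest = Config.model_path.joinpath(f"{max((*checkpoint_indices, 0))}.pth")
    if newest.exists():
        agents[-1].load_state_dict(torch.load(newest))
    # Refill a buffer from the most recent saved samples (eval mirrors the
    # str()-serialized tuples the old save code wrote) and train once.
    training_buffer = deque(maxlen=Config.training_buffer_len)
    training_buffer += [
        eval(path.read_text())
        for path in sorted(
            Config.training_data_path.iterdir(), key=lambda p: int(p.name), reverse=True
        )[: Config.training_buffer_len]
    ]
    train_agent(agents[-1], training_buffer)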
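
Likewise, the per-sample dump under sequential integer filenames moved to agent/save.py. A sketch assuming save_temp_buffer keeps the removed numbering scheme; how it actually uses the new to_train flag (here: skip persisting evaluation games) is an assumption, and the old code's reuse of start_index as the loop variable, which skipped one index, is corrected.

# agent/save.py -- a hypothetical reconstruction; the handling of to_train is
# an assumption, not visible in this diff.
from Config import Config


def save_temp_buffer(buffer, to_train):
    # Assumed behaviour: evaluation games are not persisted for training.
    if not to_train:
        return
    # Continue numbering after the highest existing sample file (-1 -> start at 0).
    start_index = max((*(int(path.name) for path in Config.training_data_path.iterdir()), -1)) + 1
    for index, sample in enumerate(buffer, start_index):
        Config.training_data_path.joinpath(str(index)).write_text(
            str((list(sample[0]), list(sample[1]), sample[2]))
        )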
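
For intuition on the agent-replacement test, a small worked example; the numbers are illustrative, not the project's actual Config values.

# Illustrative values only -- not the real Config numbers.
n_players = 4                 # len(agents)
min_games = 100               # Config.min_games_to_replace_agents
improvement = 1.5             # Config.minimal_relative_agent_improvement

# By chance each agent wins about 1 / n_players of the games, so over a full
# 100-game window the newest agent must win at least 1.5 * 100 / 4 = 37.5,
# i.e. 38 games, to be checkpointed.
assert improvement * min_games / n_players == 37.5

# Before 100 results are held, the same absolute win count is required, so a
# hot streak can trigger an early replacement; once the window is full, the
# threshold scales with however many results are held.
assert improvement * 120 / n_players == 45.0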