
Commit bf9b923

Made games more fair by replaying with switched positions
1 parent ac4ff65

3 files changed: 36 additions & 16 deletions


Config.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -7,8 +7,10 @@
 
 class _ConfigPaths:
     root = Path(__file__).parent
-    data_path = root / 'data'
-    data_path.mkdir(exist_ok=True)
+    training_data_path = root / 'training_data'
+    training_data_path.mkdir(exist_ok=True)
+    evaluation_data_path = root / 'evaluation_data'
+    evaluation_data_path.mkdir(exist_ok=True)
     model_path = root / 'models'
     model_path.mkdir(exist_ok=True)
 
@@ -26,9 +28,9 @@ class _ConfigAgent:
 class Config(_ConfigPaths, _ConfigAgent):
     max_results_held = 100
     minimal_relative_agent_improvement = 1.1
-    min_games_to_replace_agents = 20
+    min_games_to_replace_agents = 40
     train_batch_size = 64
-    training_buffer_len = 1000
+    training_buffer_len = 100_000
     min_n_points_to_finish = 15
     n_simulations = 100
     n_games = None
```
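Two of the bumped values track the behavioural change below: each call to `self_play` now plays one game per seating order, so more results arrive per iteration, which presumably motivates doubling `min_games_to_replace_agents` and enlarging the buffer. Since the `mkdir` calls run in the class body, importing the module bootstraps the directory layout; a quick hypothetical check (assuming the module imports as `Config`):

```python
# Hypothetical REPL check: importing Config runs the class bodies above,
# which create the data directories as a side effect of mkdir().
from Config import Config

print(Config.training_data_path.exists())    # True -- created at import time
print(Config.evaluation_data_path.exists())  # True
print(Config.training_buffer_len)            # 100000
```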

agent/self_play.py

Lines changed: 20 additions & 6 deletions
```diff
@@ -1,5 +1,6 @@
 from collections import deque
-from itertools import count
+from itertools import count, cycle, islice
+from more_itertools import windowed
 
 import numpy as np
 from tqdm import tqdm
@@ -10,12 +11,23 @@
 from src.Game import Game
 
 
-def self_play(agents: deque[Agent]) -> tuple[list[tuple[np.array, np.array, int]], Agent]:
+def self_play(agents: deque[Agent]) -> tuple[list[tuple[np.array, np.array, int]], list[Agent]]:
     states = []
-    game = Game(n_players=Config.n_players)
-    id_to_agent = dict((player.id, agent) for agent, player in zip(agents, game.players))
+    winners = []
+    initial_state = Game(n_players=Config.n_players)
     for agent in agents:
         agent.eval()
+    for agents_in_order in islice(windowed(cycle(agents), Config.n_players), Config.n_players):
+        game = initial_state.copy()
+        id_to_agent = dict((player.id, agent) for agent, player in zip(agents_in_order, game.players))
+        results, winner = _perform_game(game, [], id_to_agent)
+        states += results
+        winners.append(winner)
+    return states, winners
+
+
+def _perform_game(game: Game, states: list, id_to_agent: dict[int, Agent]) -> tuple[
+    list[tuple[np.array, np.array, int]], Agent]:
     for _ in tqdm(count()):
         agent = id_to_agent[game.current_player.id]
         pi, action = policy(game, agent, Config.c, Config.n_simulations)
@@ -26,5 +38,7 @@ def self_play(agents: deque[Agent]) -> tuple[list[tuple[np.array, np.array, int]
         game = game.perform(action)
         if game.is_terminal():
             result = game.get_results()
-            return (list((state[0].get_state(), state[1], int(result[state[0].current_player.id] == 1)) for state in states),
-                    id_to_agent[next(player.id for player in game.players if result[player.id])])
+            return (
+                list(
+                    (state[0].get_state(), state[1], int(result[state[0].current_player.id] == 1)) for state in states),
+                id_to_agent[next(player.id for player in game.players if result[player.id])])
```
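The seating rotation is the heart of the fairness fix: `windowed(cycle(agents), Config.n_players)` slides a window of size `n_players` over the endlessly repeated agent sequence, and `islice(..., Config.n_players)` keeps the first `n_players` windows. A standalone sketch (string stand-ins for `Agent` objects, assuming the deque holds exactly `Config.n_players` agents):

```python
# Minimal demonstration of the rotation produced by the new loop header.
from itertools import cycle, islice
from more_itertools import windowed

agents = ['A', 'B', 'C']  # stand-ins for Agent instances
n_players = len(agents)

for agents_in_order in islice(windowed(cycle(agents), n_players), n_players):
    print(agents_in_order)
# ('A', 'B', 'C')
# ('B', 'C', 'A')
# ('C', 'A', 'B')
```

Every agent thus plays the same copied initial position from every seat exactly once, so positional advantage cancels out across the replays.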

main.py

Lines changed: 10 additions & 6 deletions
```diff
@@ -18,16 +18,20 @@ def main():
     if Config.pretrain:
         for agent, checkpoint_index in zip(islice(reversed(agents), 1, None), sorted((int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), reverse=True)):
             agent.load_state_dict(torch.load(Config.model_path.joinpath(f'{checkpoint_index}.pth')))
-        agents[-1].load_state_dict(torch.load(Config.model_path.joinpath(f"{max(int(path.name.split('.')[0]) for path in Config.model_path.iterdir())}.pth")))
-        training_buffer += list(map(eval, map(Path.read_text, sorted(Config.data_path.iterdir(), key=lambda path: int(path.name), reverse=True)[:Config.training_buffer_len])))
+        newest = Config.model_path.joinpath(
+            f"{max((*tuple(int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), 0))}.pth")
+        if newest.exists():
+            agents[-1].load_state_dict(torch.load(newest))
+        training_buffer += list(map(eval, map(Path.read_text, sorted(Config.training_data_path.iterdir(), key=lambda path: int(path.name), reverse=True)[:Config.training_buffer_len])))
     train_agent(agents[-1], training_buffer)
     scores = deque(maxlen=Config.max_results_held)
     for _ in (count() if Config.n_games is None else range(Config.n_games)):
-        buffer, winner = self_play(agents)
-        start_index = max((*tuple(int(path.name) for path in Config.data_path.iterdir()), -1)) + 1
+        buffer, winners = self_play(agents)
+        start_index = max((*tuple(int(path.name) for path in Config.training_data_path.iterdir()), -1)) + 1
         for start_index, sample in enumerate(buffer, start_index + 1):
-            Config.data_path.joinpath(str(start_index)).write_text(str((list(sample[0]), list(sample[1]), sample[2])))
-        scores.append(agents[-1] is winner)
+            Config.training_data_path.joinpath(str(start_index)).write_text(str((list(sample[0]), list(sample[1]), sample[2])))
+        for winner in winners:
+            scores.append(agents[-1] is winner)
         if (len(scores) < Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * Config.min_games_to_replace_agents / len(agents)) or (len(scores) >= Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * len(scores) / len(agents)):
             torch.save(agents[-1].state_dict(), Config.model_path.joinpath(str(max(map(int, (*re.findall(r'\d+', ''.join(map(str, Config.model_path.iterdir()))), -1))) + 1) + ".pth"))
             agents.append(Agent(Config.n_players))
```
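The promotion condition is dense, so here is the same test restated with a hypothetical helper (not in the repo) and the new config values; the newest agent must beat the equal-share baseline `1/len(agents)` by the 1.1 improvement factor:

```python
# Restatement of the replacement test in main.py; scores is an iterable
# of 0/1 outcomes recording whether the newest agent won each tracked game.
def should_promote(scores, n_agents, min_games=40, improvement=1.1):
    wins, played = sum(scores), len(scores)
    if played < min_games:
        # early promotion: enough wins even before min_games games are tracked
        return wins >= improvement * min_games / n_agents
    return wins >= improvement * played / n_agents

# With 2 agents the equal-share baseline is 1/2, so a 55% win rate is needed:
print(should_promote([1] * 22, n_agents=2))             # True  (22 >= 22.0)
print(should_promote([1] * 54 + [0] * 46, n_agents=2))  # False (54 < 55.0)
```

With two agents this means a 55% win rate over the tracked games, or 22 early wins before 40 games have been recorded.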
