Commit 70d0516

Removed Null move possibility in the middle of a game
1 parent bf9b923 commit 70d0516

17 files changed: 272 additions & 84 deletions

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
 /sandbox.py
 /data/
 /models/
+/evaluation_data/
+/training_data/

Config.py

Lines changed: 12 additions & 11 deletions

@@ -7,22 +7,22 @@
 
 class _ConfigPaths:
     root = Path(__file__).parent
-    training_data_path = root / 'training_data'
+    training_data_path = root / "training_data"
     training_data_path.mkdir(exist_ok=True)
-    evaluation_data_path = root / 'evaluation_data'
+    evaluation_data_path = root / "evaluation_data"
     evaluation_data_path.mkdir(exist_ok=True)
-    model_path = root / 'models'
+    model_path = root / "models"
     model_path.mkdir(exist_ok=True)
 
 
 class _ConfigAgent:
     # hidden_sizes = (256, 128, 64, 32)
     hidden_sizes = (256,)
     # hidden_sizes = tuple()
-    c = .1
-    learning_rate = 1e-3
-    debug = False
-    pretrain = True
+    c = 0.1
+    learning_rate = 1e-4
+    debug = True
+    pretrain = False
 
 
 class Config(_ConfigPaths, _ConfigAgent):
@@ -35,10 +35,11 @@ class Config(_ConfigPaths, _ConfigAgent):
     n_simulations = 100
     n_games = None
     n_players = 2
-    n_actions = 46
+    n_actions = 45
+    eval_rate = 0.2
 
 
 if Config.debug:
-    random.seed(42)
-    np.random.seed(42)
-    torch.random.manual_seed(42)
+    random.seed(69)
+    np.random.seed(69)
+    torch.random.manual_seed(69)
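
This Config.py hunk carries the substance of the commit title: n_actions drops from 46 to 45, presumably because the null move is no longer a legal mid-game action, and the new eval_rate = 0.2 routes a share of self-play batches into the new evaluation_data folder (used by main.py below). The debug and pretrain flags and the seeds are simply flipped for the next run. A minimal sanity check of the invariant behind the 45, assuming Game and all_moves behave as the rest of this diff suggests; the check is illustrative, not part of the commit:

# Illustrative check, not committed code: the policy head width must match
# the number of moves the game exposes (45 once the null move is gone).
from Config import Config
from src.Game import Game

game = Game(n_players=Config.n_players)
assert len(game.all_moves) == Config.n_actions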

agent/Agent.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ def __init__(
         self,
         n_players: int,
         hidden_sizes: tuple = Config.hidden_sizes,
-        n_moves: int = 46,
+        n_moves: int = Config.n_actions,
     ):
         super().__init__()
         self.tanh = nn.Tanh()

agent/RLDataset.py

Lines changed: 5 additions & 2 deletions

@@ -12,5 +12,8 @@ def __len__(self):
         return len(self.examples)
 
     def __getitem__(self, index) -> tuple[np.array, ...]:
-        return np.array(self.examples[index][0]), np.array(self.examples[index][1]), np.array(
-            [self.examples[index][2] * 2 - 1])
+        return (
+            np.array(self.examples[index][0]),
+            np.array(self.examples[index][1]),
+            np.array([self.examples[index][2] * 2 - 1]),
+        )
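
The reformatted __getitem__ also makes the label transform easy to read: the stored win flag is 0 or 1, and flag * 2 - 1 rescales it to the [-1, 1] range of the network's Tanh value head (Agent.py constructs an nn.Tanh()). A one-line worked example:

# Worked example of the label mapping in __getitem__:
for win_flag in (0, 1):
    print(win_flag, "->", win_flag * 2 - 1)  # 0 -> -1, 1 -> 1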

agent/pretrain.py

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+import operator
+from collections import deque
+from functools import reduce
+from itertools import islice
+
+import torch
+
+from Config import Config
+from agent.Agent import Agent
+from agent.train_agent import train_agent
+
+
+def pretrain(agents: deque[Agent]):
+    for agent, checkpoint_index in zip(
+        islice(reversed(agents), 1, None),
+        sorted(
+            (int(path.name.split(".")[0]) for path in Config.model_path.iterdir()),
+            reverse=True,
+        ),
+    ):
+        agent.load_state_dict(
+            torch.load(Config.model_path.joinpath(f"{checkpoint_index}.pth"))
+        )
+    newest = Config.model_path.joinpath(
+        f"{max((*tuple(int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), 0))}.pth"
+    )
+    if newest.exists():
+        agents[-1].load_state_dict(torch.load(newest))
+    training_buffer = reduce(
+        operator.add,
+        (
+            deque(eval(path.read_text()))
+            for path in sorted(
+                Config.training_data_path.iterdir(), key=lambda path: int(path.name)
+            )
+        ),
+        deque(maxlen=Config.training_buffer_len),
+    )
+    train_agent(agents[-1], training_buffer)
+    return training_buffer
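
agent/pretrain.py factors the warm-start logic out of main(): every agent except the newest is restored from the numbered checkpoints under models/ (0.pth, 1.pth, ...), the newest checkpoint, if any, goes to agents[-1], and a replay buffer is rebuilt by eval-ing the text files under training_data, which therefore must be trusted local data written by save_temp_buffer below. A sketch of the intended call site, mirroring the main.py hunk later in this diff:

# Call-site sketch following main.py in this diff; note that main() currently
# discards the returned buffer, so post-pretrain training starts from an
# empty deque.
from collections import deque

from Config import Config
from agent.Agent import Agent
from agent.pretrain import pretrain

agents = deque(
    (Agent(Config.n_players) for _ in range(Config.n_players)),
    maxlen=Config.n_players,
)
if Config.pretrain:
    warm_buffer = pretrain(agents)  # loads checkpoints, replays saved games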

agent/save.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+from Config import Config
+
+
+def save_temp_buffer(buffer, train: bool):
+    path = Config.training_data_path if train else Config.evaluation_data_path
+    index = max((*tuple(int(path.name) for path in path.iterdir()), -1)) + 1
+    path.joinpath(str(index)).write_text(
+        str(list((list(sample[0]), list(sample[1]), sample[2]) for sample in buffer))
+    )
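
save_temp_buffer is the matching serializer: each self-play batch is written as the str() of a list of (state, policy, outcome) triples into the next integer-named file under training_data or evaluation_data, and pretrain reads it back with eval. A round-trip sketch with made-up toy data standing in for real game samples:

# Round-trip sketch; the toy triple stands in for a real
# (state vector, one-hot policy, win flag) sample.
from Config import Config
from agent.save import save_temp_buffer

buffer = [([0.0, 1.0], [1.0, 0.0], 1)]
save_temp_buffer(buffer, train=True)  # written to training_data/<next index>
newest = max(Config.training_data_path.iterdir(), key=lambda p: int(p.name))
assert eval(newest.read_text()) == [([0.0, 1.0], [1.0, 0.0], 1)]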

agent/search.py

Lines changed: 1 addition & 1 deletion

@@ -40,4 +40,4 @@ def search(
         N[state][action] + 1
     )
     N[state][action] += 1
-    return -v
+    return -v
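
(The identical-looking removed/added pair above is most likely a whitespace-only change, such as a trailing newline added at the end of the file, which renders as two visually identical lines.)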

agent/self_play.py

Lines changed: 25 additions & 9 deletions

@@ -11,34 +11,50 @@
 from src.Game import Game
 
 
-def self_play(agents: deque[Agent]) -> tuple[list[tuple[np.array, np.array, int]], list[Agent]]:
+def self_play(
+    agents: deque[Agent],
+) -> tuple[list[tuple[np.array, np.array, int]], list[Agent]]:
     states = []
     winners = []
     initial_state = Game(n_players=Config.n_players)
     for agent in agents:
         agent.eval()
-    for agents_in_order in islice(windowed(cycle(agents), Config.n_players), Config.n_players):
+    for agents_in_order in islice(
+        windowed(cycle(agents), Config.n_players), Config.n_players
+    ):
         game = initial_state.copy()
-        id_to_agent = dict((player.id, agent) for agent, player in zip(agents_in_order, game.players))
+        id_to_agent = dict(
+            (player.id, agent) for agent, player in zip(agents_in_order, game.players)
+        )
         results, winner = _perform_game(game, [], id_to_agent)
         states += results
         winners.append(winner)
     return states, winners
 
 
-def _perform_game(game: Game, states: list, id_to_agent: dict[int, Agent]) -> tuple[
-    list[tuple[np.array, np.array, int]], Agent]:
-    for _ in tqdm(count()):
+def _perform_game(
+    game: Game, states: list, id_to_agent: dict[int, Agent]
+) -> tuple[list[tuple[np.array, np.array, int]], Agent]:
+    for turn in tqdm(count()):
         agent = id_to_agent[game.current_player.id]
         pi, action = policy(game, agent, Config.c, Config.n_simulations)
         action_index = game.all_moves.index(action)
         onehot_encoded_action = np.zeros(Config.n_actions)
         onehot_encoded_action[action_index] = 1
-        states.append((game, onehot_encoded_action, 0))
+        states.append((game, action, 0))
         game = game.perform(action)
         if game.is_terminal():
             result = game.get_results()
             return (
                 list(
-                    (state[0].get_state(), state[1], int(result[state[0].current_player.id] == 1)) for state in states),
-                id_to_agent[next(player.id for player in game.players if result[player.id])])
+                    (
+                        state[0].get_state(),
+                        (onehot_encoded_action := np.zeros(Config.n_actions), onehot_encoded_action.__setitem__(game.all_moves.index(state[1]), 1))[0],
+                        int(result[state[0].current_player.id] == 1),
+                    )
+                    for state in states
+                ),
+                id_to_agent[
+                    next(player.id for player in game.players if result[player.id])
+                ],
+            )
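
self_play.py now defers one-hot encoding: while the game runs, each recorded state stores the raw action, states.append((game, action, 0)), and the encoding is rebuilt only after the game ends, inside the final generator, via a walrus-in-a-tuple trick: allocate a zero vector, mutate it with __setitem__, keep element [0] of the throwaway tuple. A conventional equivalent of that inline expression, shown only for readability (one_hot is a hypothetical helper, not in the repo):

import numpy as np

from Config import Config


def one_hot(action, all_moves) -> np.ndarray:
    # Equivalent of the inline walrus expression above: a zero vector of
    # length Config.n_actions with a 1 at the action's index in the move list.
    encoded = np.zeros(Config.n_actions)
    encoded[all_moves.index(action)] = 1
    return encoded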

agent/train_agent.py

Lines changed: 5 additions & 1 deletion

@@ -18,7 +18,11 @@ def train_agent(agent: Agent, train_data: deque[tuple[tuple, np.array, int]]):
     loader = DataLoader(dataset, batch_size=Config.train_batch_size)
     for batch in loader:
         state, policy, win_probability = batch
-        state, policy, win_probability = state.float(), policy.float(), win_probability.float()
+        state, policy, win_probability = (
+            state.float(),
+            policy.float(),
+            win_probability.float(),
+        )
         optimizer.zero_grad()
         output_policy, output_v = agent(state)
         bce = mse(output_v, win_probability)
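
(Only the tuple assignment is reflowed here. Worth noting in the context line: the variable named bce actually holds the value-head term computed by mse(output_v, win_probability), so the name appears to be a leftover from an earlier loss choice.)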

main.py

Lines changed: 50 additions & 22 deletions

@@ -1,49 +1,77 @@
+import random
 import re
 from collections import deque
 from copy import deepcopy
-from itertools import count, islice
-from pathlib import Path
+from itertools import count
 
 import torch
 
 from Config import Config
 from agent.Agent import Agent
+from agent.pretrain import pretrain
+from agent.save import save_temp_buffer
 from agent.self_play import self_play
 from agent.train_agent import train_agent
 
 
 def main():
     training_buffer = deque(maxlen=Config.training_buffer_len)
-    agents = deque((Agent(Config.n_players) for _ in range(Config.n_players)), maxlen=Config.n_players)
+    agents = deque(
+        (Agent(Config.n_players) for _ in range(Config.n_players)),
+        maxlen=Config.n_players,
+    )
     if Config.pretrain:
-        for agent, checkpoint_index in zip(islice(reversed(agents), 1, None), sorted((int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), reverse=True)):
-            agent.load_state_dict(torch.load(Config.model_path.joinpath(f'{checkpoint_index}.pth')))
-        newest = Config.model_path.joinpath(
-            f"{max((*tuple(int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), 0))}.pth")
-        if newest.exists():
-            agents[-1].load_state_dict(torch.load(newest))
-        training_buffer += list(map(eval, map(Path.read_text, sorted(Config.training_data_path.iterdir(), key=lambda path: int(path.name), reverse=True)[:Config.training_buffer_len])))
-        train_agent(agents[-1], training_buffer)
+        pretrain(agents)
     scores = deque(maxlen=Config.max_results_held)
-    for _ in (count() if Config.n_games is None else range(Config.n_games)):
+    for _ in count() if Config.n_games is None else range(Config.n_games):
         buffer, winners = self_play(agents)
-        start_index = max((*tuple(int(path.name) for path in Config.training_data_path.iterdir()), -1)) + 1
-        for start_index, sample in enumerate(buffer, start_index + 1):
-            Config.training_data_path.joinpath(str(start_index)).write_text(str((list(sample[0]), list(sample[1]), sample[2])))
+        to_train = random.random() > Config.eval_rate
+        save_temp_buffer(buffer, to_train)
         for winner in winners:
            scores.append(agents[-1] is winner)
-        if (len(scores) < Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * Config.min_games_to_replace_agents / len(agents)) or (len(scores) >= Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * len(scores) / len(agents)):
-            torch.save(agents[-1].state_dict(), Config.model_path.joinpath(str(max(map(int, (*re.findall(r'\d+', ''.join(map(str, Config.model_path.iterdir()))), -1))) + 1) + ".pth"))
+        if (
+            len(scores) < Config.min_games_to_replace_agents
+            and sum(scores)
+            > Config.minimal_relative_agent_improvement
+            * Config.min_games_to_replace_agents
+            / len(agents)
+        ) or (
+            len(scores) > Config.min_games_to_replace_agents
+            and sum(scores)
+            >= Config.minimal_relative_agent_improvement * len(scores) / len(agents)
+        ):
+            torch.save(
+                agents[-1].state_dict(),
+                Config.model_path.joinpath(
+                    str(
+                        max(
+                            map(
+                                int,
+                                (
+                                    *re.findall(
+                                        r"\d+",
+                                        "".join(map(str, Config.model_path.iterdir())),
+                                    ),
+                                    -1,
+                                ),
+                            )
+                        )
+                        + 1
+                    )
+                    + ".pth"
+                ),
+            )
             agents.append(Agent(Config.n_players))
-            agents[-1].load_state_dict(deepcopy(agents[-1].state_dict()))
+            agents[-1].load_state_dict(deepcopy(agents[-2].state_dict()))
             agents[-1].training = True
             scores = deque(maxlen=Config.max_results_held)
         elif len(scores) >= Config.min_games_to_replace_agents:
-            print(f'{len(scores)} {sum(scores) / len(scores):.2f}')
+            print(f"{len(scores)} {sum(scores) / len(scores):.2f}")
         else:
-            print(f'{len(scores)} {sum(scores)}/{len(scores)}')
-        training_buffer += buffer
-        train_agent(agents[-1], training_buffer)
+            print(f"{len(scores)} {sum(scores)}/{len(scores)}")
+        if to_train:
+            training_buffer += buffer
+            train_agent(agents[-1], training_buffer)
 
 
 if __name__ == "__main__":
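
Beyond formatting, the reworked main() does three things: it splits each self-play batch into training or evaluation with random.random() > Config.eval_rate, it delegates persistence to save_temp_buffer instead of writing files inline, and it extends the buffer and trains only when the batch was a training one. It also fixes a self-copy bug in the promotion step: the freshly appended challenger is now seeded from the previous champion's weights, agents[-2], instead of from its own just-initialized state dict. Condensed from the hunk above (agents has maxlen=Config.n_players, so after append the old agents[-1] becomes agents[-2]):

# Promotion step, condensed from the hunk above: append a fresh challenger,
# then copy the previous champion's weights into it; deepcopy keeps the two
# state dicts from sharing parameter tensors.
agents.append(Agent(Config.n_players))
agents[-1].load_state_dict(deepcopy(agents[-2].state_dict()))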
