
Commit c5103f3

Added evaluation
1 parent b85fc91 commit c5103f3

5 files changed: 86 additions & 23 deletions


Config.py

Lines changed: 12 additions & 6 deletions
```diff
@@ -16,20 +16,26 @@ class _ConfigPaths:
 
 
 class _ConfigAgent:
-    # hidden_sizes = (256, 128, 64, 32)
-    hidden_sizes = (256,)
-    # hidden_sizes = tuple()
-    c = 0.1
-    learning_rate = 1e-4
+    # hidden_sizes = (
+    #     256,
+    #     128,
+    #     64,
+    #     32,
+    # )
+    # hidden_sizes = (256,)
+    hidden_sizes = tuple()
+    c = 0.2
+    learning_rate = 1e-5
     debug = False
     pretrain = True
 
 
 class Config(_ConfigPaths, _ConfigAgent):
+    train = False
     max_results_held = 100
     minimal_relative_agent_improvement = 1.1
     min_games_to_replace_agents = 40
-    train_batch_size = 64
+    train_batch_size = 128
     training_buffer_len = 100_000
     min_n_points_to_finish = 15
     n_simulations = 100
```
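
Aside from reformatting the commented-out alternatives, the substantive changes are `hidden_sizes = tuple()`, a higher `c = 0.2` (presumably an exploration constant, given `n_simulations`), a lower learning rate of 1e-5, and the new `train` flag that main.py dispatches on. With an empty tuple the network body presumably collapses to a single linear layer per head. A minimal sketch of how a `hidden_sizes` tuple is typically consumed; `build_mlp` is a hypothetical helper, not code from this repo:

```python
from torch import nn


def build_mlp(in_features: int, hidden_sizes: tuple[int, ...], out_features: int) -> nn.Sequential:
    # One Linear+ReLU block per entry; an empty tuple degenerates
    # to a single Linear layer mapping inputs straight to outputs.
    layers: list[nn.Module] = []
    for size in hidden_sizes:
        layers += [nn.Linear(in_features, size), nn.ReLU()]
        in_features = size
    layers.append(nn.Linear(in_features, out_features))
    return nn.Sequential(*layers)
```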

agent/RLDataset.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,11 +1,11 @@
-from collections import deque
+from typing import Sequence
 
 import numpy as np
 from torch.utils.data import Dataset
 
 
 class RLDataset(Dataset):
-    def __init__(self, examples: deque[tuple[tuple, np.array, int]]):
+    def __init__(self, examples: Sequence[tuple[tuple, np.array, int]]):
         self.examples = examples
 
     def __len__(self):
```
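
Widening the annotation from `deque[...]` to `Sequence[...]` matters because `evaluation()` in main.py now passes plain lists. (Strictly, `np.array` is NumPy's factory function rather than a type; `np.ndarray` would be the accurate annotation.) The rest of the class is not shown in the diff; under the usual `Dataset` contract it plausibly reads:

```python
from typing import Sequence

import numpy as np
from torch.utils.data import Dataset


class RLDataset(Dataset):
    def __init__(self, examples: Sequence[tuple[tuple, np.ndarray, int]]):
        self.examples = examples

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> tuple[tuple, np.ndarray, int]:
        # Each example is (state, one-hot policy target, win label).
        return self.examples[index]
```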

agent/self_play.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -20,7 +20,8 @@ def self_play(
     for agent in agents:
         agent.eval()
     id_to_agent = dict(
-        (player.id, agent) for agent, player in zip(random.sample(agents, Config.n_players), game.players)
+        (player.id, agent)
+        for agent, player in zip(random.sample(agents, Config.n_players), game.players)
     )
     results, winner = _perform_game(game, [], id_to_agent)
     states += results
@@ -48,7 +49,8 @@ def _perform_game(
                 np.eye(Config.n_actions)[game.all_moves.index(state[1])],
                 int(result[state[0].current_player.id] == 1),
             )
-            for state in states if state[1] != game.null_move
+            for state in states
+            if state[1] != game.null_move
         ),
         id_to_agent[
             next(player.id for player in game.players if result[player.id])
```
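
Both hunks are pure line-length reformatting; behavior is unchanged. As a style aside, the `dict(...)` call over a generator of pairs could equally be written as a dict comprehension, which many linters prefer:

```python
id_to_agent = {
    player.id: agent
    for agent, player in zip(random.sample(agents, Config.n_players), game.players)
}
```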

agent/train_agent.py

Lines changed: 37 additions & 11 deletions
```diff
@@ -1,6 +1,7 @@
-from collections import deque
+from typing import Sequence
 
 import numpy as np
+from sklearn.metrics import accuracy_score
 from torch import nn, optim
 from torch.utils.data import DataLoader
 
@@ -9,24 +10,49 @@
 from .RLDataset import RLDataset
 
 
-def train_agent(agent: Agent, train_data: deque[tuple[tuple, np.array, int]]):
+def train_agent(agent: Agent, train_data: Sequence[tuple[tuple, np.array, int]]):
     agent.train()
+    optimizer = optim.Adam(agent.parameters(), lr=Config.learning_rate)
+    _loop(agent, train_data, optimizer)
+
+
+def eval_agent(agent: Agent, eval_set: Sequence[tuple[tuple, np.array, int]]):
+    agent.eval()
+    return _loop(agent, eval_set, batch_size=len(eval_set))
+
+
+def _loop(
+    agent: Agent,
+    dataset: Sequence[tuple[tuple, np.array, int]],
+    optimizer: optim.Optimizer = None,
+    batch_size=Config.train_batch_size,
+):
+    is_optimizer = optimizer is not None
     categorical_cross_entropy = nn.CrossEntropyLoss()
     mse = nn.MSELoss()
-    optimizer = optim.Adam(agent.parameters(), lr=Config.learning_rate)
-    dataset = RLDataset(train_data)
-    loader = DataLoader(dataset, batch_size=Config.train_batch_size)
-    for batch in loader:
-        state, policy, win_probability = batch
+    dataset = RLDataset(dataset)
+    loader = DataLoader(dataset, batch_size=batch_size)
+    for index, (state, policy, win_probability) in enumerate(loader):
         state, policy, win_probability = (
             state.float(),
             policy.float(),
             win_probability.float(),
         )
-        optimizer.zero_grad()
+        if is_optimizer:
+            optimizer.zero_grad()
         output_policy, output_v = agent(state)
         bce = mse(output_v, win_probability)
         cce = categorical_cross_entropy(output_policy, policy)
-        bce.backward(retain_graph=True)
-        cce.backward()
-        optimizer.step()
+        if is_optimizer:
+            bce.backward(retain_graph=True)
+            cce.backward()
+            optimizer.step()
+        else:
+            print(
+                accuracy_score(win_probability, np.sign(output_v.detach().numpy())),
+                accuracy_score(
+                    np.argmax(policy.detach().numpy(), axis=1),
+                    np.argmax(output_policy.detach().numpy(), axis=1),
+                ),
+            )
+    return bce.item(), cce.item()
```
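
Factoring both modes into `_loop` means `eval_agent` runs the whole set as a single batch, so the `bce.item(), cce.item()` returned after the last (here, only) iteration are the losses over the full evaluation set. Despite its name, `bce` holds an MSE value loss. The two separate `backward()` calls are what force `retain_graph=True`; the conventional equivalent is one pass over the summed loss (a sketch of the alternative, not what the commit does):

```python
if is_optimizer:
    optimizer.zero_grad()
    # grad(bce + cce) == grad(bce) + grad(cce), so one traversal
    # of the graph suffices and retain_graph is unnecessary.
    (bce + cce).backward()
    optimizer.step()
```

One caveat on the printed metrics: `np.sign` maps `output_v` into {-1, 0, 1}, while the labels produced in self_play.py are {0, 1}, so the value-head accuracy only lines up if the outputs are already non-negative; thresholding at 0.5 would match the target encoding directly.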

main.py

Lines changed: 31 additions & 2 deletions
```diff
@@ -1,7 +1,9 @@
+import operator
 import random
 import re
 from collections import deque
 from copy import deepcopy
+from functools import reduce
 from itertools import count
 
 import torch
@@ -11,10 +13,10 @@
 from agent.pretrain import pretrain
 from agent.save import save_temp_buffer
 from agent.self_play import self_play
-from agent.train_agent import train_agent
+from agent.train_agent import train_agent, eval_agent
 
 
-def main():
+def train_loop():
     training_buffer = deque(maxlen=Config.training_buffer_len)
     agents = deque(
         (Agent(Config.n_players) for _ in range(Config.n_players)),
@@ -74,5 +76,32 @@ def main():
         train_agent(agents[-1], training_buffer)
 
 
+def evaluation():
+    agent = Agent(Config.n_players)
+    train_set = reduce(
+        operator.add,
+        (eval(path.read_text()) for path in Config.training_data_path.iterdir()),
+    )
+    eval_set = reduce(
+        operator.add,
+        (eval(path.read_text()) for path in Config.evaluation_data_path.iterdir()),
+    )
+    prev_bce, prev_cce = float("inf"), float("inf")
+    while True:
+        train_agent(agent, train_set)
+        bce, cce = eval_agent(agent, eval_set)
+        if bce >= prev_bce and cce >= prev_cce:
+            break
+        prev_bce = bce
+        prev_cce = cce
+
+
+def main():
+    if Config.train:
+        train_loop()
+    else:
+        evaluation()
+
+
 if __name__ == "__main__":
     main()
```
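
The old `main()` becomes `train_loop()`, and the new `main()` dispatches on the `Config.train` flag; `evaluation()` trains a fresh agent on the concatenated training files until neither loss improves on the held-out set. Two implementation notes: `eval(path.read_text())` executes whatever the file contains, and repeated `operator.add` on lists is quadratic. Since the files are already parsed with bare `eval`, they presumably hold plain Python literals, in which case `ast.literal_eval` plus `itertools.chain` covers both points. A sketch under that assumption; `load_examples` is a hypothetical helper:

```python
import ast
from itertools import chain
from pathlib import Path


def load_examples(directory: Path) -> list:
    # Safely parse files containing Python literal lists of examples,
    # concatenating them in linear time instead of repeated list adds.
    return list(
        chain.from_iterable(ast.literal_eval(path.read_text()) for path in directory.iterdir())
    )
```

Usage, given main.py's existing Config import: `train_set = load_examples(Config.training_data_path)` and `eval_set = load_examples(Config.evaluation_data_path)`.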
