
Commit c4c9a24

Merge pull request #3 from Tesla2000/feature/evaluation
Feature/evaluation
2 parents ac4ff65 + c5103f3 commit c4c9a24

23 files changed: 372 additions & 114 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 /sandbox.py
 /data/
 /models/
+/evaluation_data/
+/training_data/

Config.py

Lines changed: 21 additions & 12 deletions
@@ -7,33 +7,42 @@

 class _ConfigPaths:
     root = Path(__file__).parent
-    data_path = root / 'data'
-    data_path.mkdir(exist_ok=True)
-    model_path = root / 'models'
+    training_data_path = root / "training_data"
+    training_data_path.mkdir(exist_ok=True)
+    evaluation_data_path = root / "evaluation_data"
+    evaluation_data_path.mkdir(exist_ok=True)
+    model_path = root / "models"
     model_path.mkdir(exist_ok=True)


 class _ConfigAgent:
-    # hidden_sizes = (256, 128, 64, 32)
-    hidden_sizes = (256,)
-    # hidden_sizes = tuple()
-    c = .1
-    learning_rate = 1e-3
+    # hidden_sizes = (
+    #     256,
+    #     128,
+    #     64,
+    #     32,
+    # )
+    # hidden_sizes = (256,)
+    hidden_sizes = tuple()
+    c = 0.2
+    learning_rate = 1e-5
     debug = False
     pretrain = True


 class Config(_ConfigPaths, _ConfigAgent):
+    train = False
     max_results_held = 100
     minimal_relative_agent_improvement = 1.1
-    min_games_to_replace_agents = 20
-    train_batch_size = 64
-    training_buffer_len = 1000
+    min_games_to_replace_agents = 40
+    train_batch_size = 128
+    training_buffer_len = 100_000
     min_n_points_to_finish = 15
     n_simulations = 100
     n_games = None
     n_players = 2
-    n_actions = 46
+    n_actions = 45
+    eval_rate = 0.2


 if Config.debug:
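The new Config fields (train, eval_rate, the larger replay buffer and lower learning rate) point at a train/evaluation split of the self-play data, but the code that consumes eval_rate is not part of this diff. A purely illustrative sketch of how such a split could be wired up with the save_temp_buffer helper added below; route_buffer is a hypothetical name, not from the repository:

import random

from Config import Config
from agent.save import save_temp_buffer


def route_buffer(buffer):
    # Hypothetical helper: with probability eval_rate, store the finished
    # game's samples as evaluation data, otherwise as training data.
    is_train = random.random() >= Config.eval_rate
    save_temp_buffer(buffer, train=is_train)
    return is_train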

agent/Agent.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def __init__(
         self,
         n_players: int,
         hidden_sizes: tuple = Config.hidden_sizes,
-        n_moves: int = 46,
+        n_moves: int = Config.n_actions,
     ):
         super().__init__()
         self.tanh = nn.Tanh()

agent/RLDataset.py

Lines changed: 7 additions & 4 deletions
@@ -1,16 +1,19 @@
-from collections import deque
+from typing import Sequence

 import numpy as np
 from torch.utils.data import Dataset


 class RLDataset(Dataset):
-    def __init__(self, examples: deque[tuple[tuple, np.array, int]]):
+    def __init__(self, examples: Sequence[tuple[tuple, np.array, int]]):
         self.examples = examples

     def __len__(self):
         return len(self.examples)

     def __getitem__(self, index) -> tuple[np.array, ...]:
-        return np.array(self.examples[index][0]), np.array(self.examples[index][1]), np.array(
-            [self.examples[index][2] * 2 - 1])
+        return (
+            np.array(self.examples[index][0]),
+            np.array(self.examples[index][1]),
+            np.array([self.examples[index][2] * 2 - 1]),
+        )
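__getitem__ rescales the stored 0/1 win flag into the [-1, 1] range of the value head (label * 2 - 1). A small sketch of that mapping, assuming one sample shaped like the triples produced in self_play; the 45-wide one-hot follows Config.n_actions:

import numpy as np

from agent.RLDataset import RLDataset

# One sample: (flattened game state, one-hot policy target, did-the-player-win flag)
example = ((0.0, 1.0, 0.5), np.eye(45)[3], 1)
dataset = RLDataset([example])

state, policy, value = dataset[0]
assert value.item() == 1  # win flag 1 maps to +1; a loss (0) would map to -1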

agent/pretrain.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+import operator
+from collections import deque
+from functools import reduce
+from itertools import islice
+
+import torch
+
+from Config import Config
+from agent.Agent import Agent
+from agent.train_agent import train_agent
+
+
+def pretrain(agents: deque[Agent]):
+    for agent, checkpoint_index in zip(
+        islice(reversed(agents), 1, None),
+        sorted(
+            (int(path.name.split(".")[0]) for path in Config.model_path.iterdir()),
+            reverse=True,
+        ),
+    ):
+        agent.load_state_dict(
+            torch.load(Config.model_path.joinpath(f"{checkpoint_index}.pth"))
+        )
+    newest = Config.model_path.joinpath(
+        f"{max((*tuple(int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), 0))}.pth"
+    )
+    if newest.exists():
+        agents[-1].load_state_dict(torch.load(newest))
+    training_buffer = reduce(
+        operator.add,
+        (
+            deque(eval(path.read_text()))
+            for path in sorted(
+                Config.training_data_path.iterdir(), key=lambda path: int(path.name)
+            )
+        ),
+        deque(maxlen=Config.training_buffer_len),
+    )
+    train_agent(agents[-1], training_buffer)
+    return training_buffer
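A hedged usage sketch: pretrain expects a deque of agents, checkpoints named <index>.pth under Config.model_path, and serialized buffers under Config.training_data_path (as written by save_temp_buffer). It loads the newest weights into the last agent, replays the saved buffers through train_agent, and returns the rebuilt buffer. The agent count below is an assumption for illustration, not something this diff fixes:

from collections import deque

from Config import Config
from agent.Agent import Agent
from agent.pretrain import pretrain

# Assumes checkpoints and training data are already on disk.
agents = deque(Agent(n_players=Config.n_players) for _ in range(3))
training_buffer = pretrain(agents)
print(len(training_buffer))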

agent/save.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+from Config import Config
+
+
+def save_temp_buffer(buffer, train: bool):
+    path = Config.training_data_path if train else Config.evaluation_data_path
+    index = max((*tuple(int(path.name) for path in path.iterdir()), -1)) + 1
+    path.joinpath(str(index)).write_text(
+        str(list((list(sample[0]), list(sample[1]), sample[2]) for sample in buffer))
+    )
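save_temp_buffer serializes the buffer as a Python literal (str(list(...))), which is exactly what pretrain later reads back with eval(path.read_text()). A short usage sketch, assuming the (state, one-hot action, win flag) sample layout used elsewhere in this diff:

import numpy as np

from agent.save import save_temp_buffer

buffer = [
    ((0.0, 1.0, 0.5), np.eye(45)[3], 1),  # (state, one-hot action, win flag)
]
save_temp_buffer(buffer, train=True)   # written to Config.training_data_path/<next index>
save_temp_buffer(buffer, train=False)  # written to Config.evaluation_data_path/<next index>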

agent/search.py

Lines changed: 1 addition & 1 deletion
@@ -40,4 +40,4 @@ def search(
         N[state][action] + 1
     )
     N[state][action] += 1
-    return -v
+    return -v

agent/self_play.py

Lines changed: 35 additions & 7 deletions
@@ -1,30 +1,58 @@
+import random
 from collections import deque
 from itertools import count

 import numpy as np
 from tqdm import tqdm

 from Config import Config
+from src.Game import Game
 from .Agent import Agent
 from .policy import policy
-from src.Game import Game


-def self_play(agents: deque[Agent]) -> tuple[list[tuple[np.array, np.array, int]], Agent]:
+def self_play(
+    agents: deque[Agent],
+) -> tuple[list[tuple[np.array, np.array, int]], list[Agent]]:
     states = []
+    winners = []
     game = Game(n_players=Config.n_players)
-    id_to_agent = dict((player.id, agent) for agent, player in zip(agents, game.players))
     for agent in agents:
         agent.eval()
-    for _ in tqdm(count()):
+    id_to_agent = dict(
+        (player.id, agent)
+        for agent, player in zip(random.sample(agents, Config.n_players), game.players)
+    )
+    results, winner = _perform_game(game, [], id_to_agent)
+    states += results
+    winners.append(winner)
+    return states, winners
+
+
+def _perform_game(
+    game: Game, states: list, id_to_agent: dict[int, Agent]
+) -> tuple[list[tuple[np.array, np.array, int]], Agent]:
+    for turn in tqdm(count()):
         agent = id_to_agent[game.current_player.id]
         pi, action = policy(game, agent, Config.c, Config.n_simulations)
         action_index = game.all_moves.index(action)
         onehot_encoded_action = np.zeros(Config.n_actions)
         onehot_encoded_action[action_index] = 1
-        states.append((game, onehot_encoded_action, 0))
+        states.append((game, action, 0))
         game = game.perform(action)
         if game.is_terminal():
             result = game.get_results()
-            return (list((state[0].get_state(), state[1], int(result[state[0].current_player.id] == 1)) for state in states),
-                    id_to_agent[next(player.id for player in game.players if result[player.id])])
+            return (
+                list(
+                    (
+                        state[0].get_state(),
+                        np.eye(Config.n_actions)[game.all_moves.index(state[1])],
+                        int(result[state[0].current_player.id] == 1),
+                    )
+                    for state in states
+                    if state[1] != game.null_move
+                ),
+                id_to_agent[
+                    next(player.id for player in game.players if result[player.id])
+                ],
+            )
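self_play now draws Config.n_players agents at random from the pool, plays one game via _perform_game, and returns the collected (state, one-hot action, win flag) triples together with the winning agents. A hedged sketch of pairing it with the buffer saver above; the two fresh agents are an assumption for illustration:

from collections import deque

from Config import Config
from agent.Agent import Agent
from agent.save import save_temp_buffer
from agent.self_play import self_play

agents = deque(Agent(n_players=Config.n_players) for _ in range(2))
states, winners = self_play(agents)           # one game's samples plus its winner
save_temp_buffer(states, train=Config.train)  # persist for later pretraining or evaluation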

agent/train_agent.py

Lines changed: 42 additions & 12 deletions
@@ -1,6 +1,7 @@
-from collections import deque
+from typing import Sequence

 import numpy as np
+from sklearn.metrics import accuracy_score
 from torch import nn, optim
 from torch.utils.data import DataLoader

@@ -9,20 +10,49 @@
 from .RLDataset import RLDataset


-def train_agent(agent: Agent, train_data: deque[tuple[tuple, np.array, int]]):
+def train_agent(agent: Agent, train_data: Sequence[tuple[tuple, np.array, int]]):
     agent.train()
+    optimizer = optim.Adam(agent.parameters(), lr=Config.learning_rate)
+    _loop(agent, train_data, optimizer)
+
+
+def eval_agent(agent: Agent, eval_set: Sequence[tuple[tuple, np.array, int]]):
+    agent.eval()
+    return _loop(agent, eval_set, batch_size=len(eval_set))
+
+
+def _loop(
+    agent: Agent,
+    dataset: Sequence[tuple[tuple, np.array, int]],
+    optimizer: optim.Optimizer = None,
+    batch_size=Config.train_batch_size,
+):
+    is_optimizer = optimizer is not None
     categorical_cross_entropy = nn.CrossEntropyLoss()
     mse = nn.MSELoss()
-    optimizer = optim.Adam(agent.parameters(), lr=Config.learning_rate)
-    dataset = RLDataset(train_data)
-    loader = DataLoader(dataset, batch_size=Config.train_batch_size)
-    for batch in loader:
-        state, policy, win_probability = batch
-        state, policy, win_probability = state.float(), policy.float(), win_probability.float()
-        optimizer.zero_grad()
+    dataset = RLDataset(dataset)
+    loader = DataLoader(dataset, batch_size=batch_size)
+    for index, (state, policy, win_probability) in enumerate(loader):
+        state, policy, win_probability = (
+            state.float(),
+            policy.float(),
+            win_probability.float(),
+        )
+        if is_optimizer:
+            optimizer.zero_grad()
         output_policy, output_v = agent(state)
         bce = mse(output_v, win_probability)
         cce = categorical_cross_entropy(output_policy, policy)
-        bce.backward(retain_graph=True)
-        cce.backward()
-        optimizer.step()
+        if is_optimizer:
+            bce.backward(retain_graph=True)
+            cce.backward()
+            optimizer.step()
+        else:
+            print(
+                accuracy_score(win_probability, np.sign(output_v.detach().numpy())),
+                accuracy_score(
+                    np.argmax(policy.detach().numpy(), axis=1),
+                    np.argmax(output_policy.detach().numpy(), axis=1),
+                ),
+            )
+    return bce.item(), cce.item()
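_loop now serves both modes: with an optimizer it backpropagates the MSE value loss and the cross-entropy policy loss; without one (eval_agent) it prints value-sign and policy-argmax accuracies and returns the final losses. A hedged sketch of the call sequence, reusing self-play output as data; evaluating on the training buffer is only for illustration, a held-out buffer would normally be used:

from collections import deque

from Config import Config
from agent.Agent import Agent
from agent.self_play import self_play
from agent.train_agent import eval_agent, train_agent

agents = deque(Agent(n_players=Config.n_players) for _ in range(2))
buffer, _winners = self_play(agents)
train_agent(agents[-1], buffer)                           # one optimization pass
value_loss, policy_loss = eval_agent(agents[-1], buffer)  # prints accuracies, returns losses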

hashabledict.py

Lines changed: 0 additions & 3 deletions
This file was deleted.
