
Commit ac4ff65

Merge pull request #2 from Tesla2000/feature/training-loop
Feature/training loop
2 parents 7654d78 + 632fc3c commit ac4ff65

20 files changed

Lines changed: 276 additions & 139 deletions

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+/sandbox.py
+/data/
+/models/

Config.py

Lines changed: 32 additions & 5 deletions

@@ -1,15 +1,42 @@
 import random
+from pathlib import Path

 import numpy as np
 import torch

-random.seed(42)
-np.random.seed(42)
-torch.random.manual_seed(42)

+class _ConfigPaths:
+    root = Path(__file__).parent
+    data_path = root / 'data'
+    data_path.mkdir(exist_ok=True)
+    model_path = root / 'models'
+    model_path.mkdir(exist_ok=True)

-class Config:
+
+class _ConfigAgent:
+    # hidden_sizes = (256, 128, 64, 32)
+    hidden_sizes = (256,)
+    # hidden_sizes = tuple()
+    c = .1
+    learning_rate = 1e-3
+    debug = False
+    pretrain = True
+
+
+class Config(_ConfigPaths, _ConfigAgent):
+    max_results_held = 100
+    minimal_relative_agent_improvement = 1.1
+    min_games_to_replace_agents = 20
+    train_batch_size = 64
+    training_buffer_len = 1000
     min_n_points_to_finish = 15
     n_simulations = 100
-    n_games = 1
+    n_games = None
     n_players = 2
+    n_actions = 46
+
+
+if Config.debug:
+    random.seed(42)
+    np.random.seed(42)
+    torch.random.manual_seed(42)
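
A minimal sketch of how the merged Config resolves (assuming the file above is importable from the project root; attribute lookup simply follows Python's MRO through _ConfigPaths and _ConfigAgent). Note that with debug = False the seeds are no longer fixed, so runs are non-deterministic by default.

from Config import Config

print(Config.hidden_sizes)  # (256,) -- inherited from _ConfigAgent
print(Config.data_path)     # <project root>/data -- inherited from _ConfigPaths
print(Config.n_games)       # None, which main.py treats as "play indefinitely"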

agent/Agent.py

Lines changed: 10 additions & 12 deletions

@@ -3,42 +3,40 @@
 import numpy as np
 from torch import nn, Tensor

+from Config import Config
+

 class Agent(nn.Module):
     _input_size_dictionary = {
-        2: 215,
+        2: 205,
     }

     def __init__(
         self,
         n_players: int,
-        hidden_sizes: tuple = (256, 128, 64, 32),
+        hidden_sizes: tuple = Config.hidden_sizes,
         n_moves: int = 46,
     ):
         super().__init__()
-        self.relu = nn.ReLU()
         self.tanh = nn.Tanh()
         self.softmax = nn.Softmax(dim=1)
         first_size = self._get_size(n_players)
         sizes = first_size, *hidden_sizes
-        self.layers = tuple(starmap(nn.Linear, pairwise(sizes)))
-        for index, layer in enumerate(self.layers):
-            setattr(self, f"layer_{index}", layer)
-        self.fc_v = nn.Linear(hidden_sizes[-1], 1)
-        self.fc_p = nn.Linear(hidden_sizes[-1], n_moves)
+        self.layers = nn.ModuleList(starmap(nn.Linear, pairwise(sizes)))
+        self.trained = True
+        self.fc_v = nn.Linear(sizes[-1], 1)
+        self.fc_p = nn.Linear(sizes[-1], n_moves)
         self._n_moves = n_moves
-        self._trained = False

     def _get_size(self, n_players: int) -> int:
         return self._input_size_dictionary[n_players]

     def forward(self, state: Tensor):
-        if not self.training and not self._trained:
+        if not self.training and not self.trained:
             return self.softmax(Tensor(np.random.random((1, self._n_moves)))), Tensor(
                 np.random.uniform(-1, 1, (1, 1))
             )
-        self._trained = True
+        self.trained = True
         for layer in self.layers:
             state = layer(state)
-            state = self.relu(state)
         return self.softmax(self.fc_p(state)), self.tanh(self.fc_v(state))
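
A quick shape check of the two-headed network, as a sketch (assumes the repository root is on sys.path, the 2-player input size of 205 from the dictionary above, and the default hidden_sizes = (256,) from Config). With the ReLU dropped in this commit, the hidden stack composes to an affine map, so only the softmax and tanh heads are non-linear.

import torch
from agent.Agent import Agent

agent = Agent(n_players=2)  # pairwise sizes build a single Linear(205, 256)
state = torch.rand(1, 205)  # dummy batch of one encoded 2-player state
p, v = agent(state)         # module starts in training mode, so no random branch
print(p.shape)              # torch.Size([1, 46]) -- softmax over the 46 moves
print(v.shape)              # torch.Size([1, 1])  -- tanh value in (-1, 1)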

agent/RLDataset.py

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+from collections import deque
+
+import numpy as np
+from torch.utils.data import Dataset
+
+
+class RLDataset(Dataset):
+    def __init__(self, examples: deque[tuple[tuple, np.array, int]]):
+        self.examples = examples
+
+    def __len__(self):
+        return len(self.examples)
+
+    def __getitem__(self, index) -> tuple[np.array, ...]:
+        return np.array(self.examples[index][0]), np.array(self.examples[index][1]), np.array(
+            [self.examples[index][2] * 2 - 1])
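
The third element of each example is a 0/1 win flag, and the * 2 - 1 rescales it to the {-1, +1} range of the agent's tanh value head. A minimal usage sketch with hypothetical toy examples:

from collections import deque

import numpy as np

from agent.RLDataset import RLDataset

toy = deque([((0.0, 1.0), np.ones(46) / 46, 1),
             ((1.0, 0.0), np.ones(46) / 46, 0)])
dataset = RLDataset(toy)
print(dataset[0][2])  # [1]  -- a win maps to +1
print(dataset[1][2])  # [-1] -- a loss maps to -1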

agent/policy.py

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+from collections import defaultdict
+
+import numpy as np
+from torch import nn
+
+from agent.search import search
+from src.Game import Game
+
+
+def policy(
+    game: Game,
+    agent: nn.Module,
+    c: float,
+    n_simulations: int,
+):
+    N = defaultdict(lambda: defaultdict(int))
+    visited = set()
+    P = defaultdict(dict)
+    Q = defaultdict(dict)
+    initial_state = game.get_state()
+    all_moves = game.get_possible_actions()
+    for _ in range(n_simulations):
+        search(game.copy(), agent, c, N, visited, P, Q)
+    pi = [N[initial_state][a] for a in all_moves]
+    return pi, all_moves[np.argmax(pi)]
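
pi here is the vector of raw root visit counts and the returned action is its argmax; nothing downstream requires a distribution, since self_play.py one-hot encodes the chosen action anyway. If a probability target were wanted instead (the usual AlphaZero formulation), normalizing is a one-liner, sketched with made-up counts:

import numpy as np

pi = np.array([30, 50, 20])  # hypothetical root visit counts
pi_prob = pi / pi.sum()      # array([0.3, 0.5, 0.2])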

agent/search.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+from collections import defaultdict
+from math import sqrt
+
+from torch import nn, Tensor
+
+from src.Game import Game
+
+
+def search(
+    game: Game,
+    agent: nn.Module,
+    c: float,
+    N: defaultdict,
+    visited: set,
+    P: defaultdict,
+    Q: defaultdict,
+):
+    if game.is_terminal():
+        return -game.get_results()[game.current_player.id]
+    state = game.get_state()
+    if state not in visited:
+        visited.add(state)
+        move_scores, v = agent(Tensor([state]))
+        tuple(
+            P[state].__setitem__(move, move_scores[0, index])
+            for index, move in enumerate(game.all_moves)
+        )
+        return -v
+
+    action = max(
+        game.get_possible_actions(),
+        key=lambda action: Q[state].get(action, 1)
+        + c * P[state][action] * sqrt(sum(N[state].values())) / (1 + N[state][action]),
+    )
+
+    next_game_state = game.perform(action)
+    v = search(next_game_state, agent, c, N, visited, P, Q)
+
+    Q[state][action] = (N[state][action] * Q[state].get(action, 1) + v) / (
+        N[state][action] + 1
+    )
+    N[state][action] += 1
+    return -v
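
The key function in the max call is the PUCT selection rule familiar from AlphaZero-style MCTS; in the notation of the code it picks

    a* = argmax_a [ Q(s, a) + c * P(s, a) * sqrt(sum_b N(s, b)) / (1 + N(s, a)) ]

with the twist that unvisited actions default to Q(s, a) = 1 (an optimistic initialization) rather than 0. The value returned from each recursive call is negated because it is scored from the perspective of the player to move at that node.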

agent/self_play.py

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+from collections import deque
+from itertools import count
+
+import numpy as np
+from tqdm import tqdm
+
+from Config import Config
+from .Agent import Agent
+from .policy import policy
+from src.Game import Game
+
+
+def self_play(agents: deque[Agent]) -> tuple[list[tuple[np.array, np.array, int]], Agent]:
+    states = []
+    game = Game(n_players=Config.n_players)
+    id_to_agent = dict((player.id, agent) for agent, player in zip(agents, game.players))
+    for agent in agents:
+        agent.eval()
+    for _ in tqdm(count()):
+        agent = id_to_agent[game.current_player.id]
+        pi, action = policy(game, agent, Config.c, Config.n_simulations)
+        action_index = game.all_moves.index(action)
+        onehot_encoded_action = np.zeros(Config.n_actions)
+        onehot_encoded_action[action_index] = 1
+        states.append((game, onehot_encoded_action, 0))
+        game = game.perform(action)
+        if game.is_terminal():
+            result = game.get_results()
+            return (list((state[0].get_state(), state[1], int(result[state[0].current_player.id] == 1)) for state in states),
+                    id_to_agent[next(player.id for player in game.players if result[player.id])])
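
A minimal driver sketch (assumes a working src.Game implementation and the Config above; self_play yields one training triple per move played, plus the agent mapped to the winning player):

from collections import deque

from Config import Config
from agent.Agent import Agent
from agent.self_play import self_play

agents = deque((Agent(Config.n_players) for _ in range(Config.n_players)),
               maxlen=Config.n_players)
buffer, winner = self_play(agents)
# buffer: [(state tuple, one-hot action over 46 moves, 0/1 win flag), ...]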

agent/train_agent.py

Lines changed: 23 additions & 82 deletions

@@ -1,87 +1,28 @@
-from collections import defaultdict
-from dataclasses import astuple
-from math import sqrt
+from collections import deque

 import numpy as np
-from torch import nn, Tensor
-from tqdm import tqdm
+from torch import nn, optim
+from torch.utils.data import DataLoader

 from Config import Config
-from src.Game import Game
 from .Agent import Agent
-
-
-def train_agent():
-    agent = Agent(Config.n_players)
-    agent.eval()
-    examples = []
-    examples_per_game = []
-    for i in range(Config.n_games):
-        game = Game(n_players=Config.n_players)
-        while True:
-            pi, action = policy(game, agent, 1, Config.n_simulations)
-            examples_per_game.append((game, pi, 0))
-            game = game.perform(action)
-            print(len(game.players[1].cards), game.players[1].points)
-            if game.is_terminal():
-                for example in examples_per_game:
-                    example[2] = game.get_state()
-                break
-        examples += examples_per_game
-        break
-    return examples
-
-
-def search(
-    game: Game,
-    agent: nn.Module,
-    c: float,
-    N: defaultdict,
-    visited: set,
-    P: defaultdict,
-    Q: defaultdict,
-):
-    state = game.get_state()
-    if game.is_terminal():
-        return game.get_results()[game.current_player]
-    if state not in visited:
-        visited.add(state)
-        move_scores, v = agent(Tensor([state]))
-        tuple(
-            P[state].__setitem__(move, move_scores[0, index])
-            for index, move in enumerate(game.all_moves)
-        )
-        return -v
-
-    action = max(
-        game.get_possible_actions(),
-        key=lambda action: Q[state].get(action, 1)
-        + c * P[state][action] * sqrt(sum(N[state].values())) / (1 + N[state][action]),
-    )
-
-    next_game_state = game.perform(action)
-    v = search(next_game_state, agent, c, N, visited, P, Q)
-
-    Q[state][action] = (N[state][action] * Q[state].get(action, 1) + v) / (
-        N[state][action] + 1
-    )
-    N[state][action] += 1
-    return -v
-
-
-def policy(
-    game: Game,
-    agent: nn.Module,
-    c: float,
-    n_simulations: int,
-):
-    N = defaultdict(lambda: defaultdict(int))
-    visited = set()
-    P = defaultdict(dict)
-    Q = defaultdict(dict)
-    initial_state = game.get_state()
-    all_moves = game.get_possible_actions()
-    for _ in tqdm(range(n_simulations)):
-        search(game, agent, c, N, visited, P, Q)
-    pi = [N[initial_state][a] for a in all_moves]
-    return pi, all_moves[np.argmax(pi)]
+from .RLDataset import RLDataset
+
+
+def train_agent(agent: Agent, train_data: deque[tuple[tuple, np.array, int]]):
+    agent.train()
+    categorical_cross_entropy = nn.CrossEntropyLoss()
+    mse = nn.MSELoss()
+    optimizer = optim.Adam(agent.parameters(), lr=Config.learning_rate)
+    dataset = RLDataset(train_data)
+    loader = DataLoader(dataset, batch_size=Config.train_batch_size)
+    for batch in loader:
+        state, policy, win_probability = batch
+        state, policy, win_probability = state.float(), policy.float(), win_probability.float()
+        optimizer.zero_grad()
+        output_policy, output_v = agent(state)
+        bce = mse(output_v, win_probability)
+        cce = categorical_cross_entropy(output_policy, policy)
+        bce.backward(retain_graph=True)
+        cce.backward()
+        optimizer.step()
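
Despite the local name bce, the value loss is plain MSE. Because PyTorch accumulates gradients across backward() calls, the pair of backward passes (retain_graph=True on the first keeps the shared activations alive for the second) is equivalent to a single backward on the summed loss; a sketch of the equivalent loop body:

optimizer.zero_grad()
output_policy, output_v = agent(state)
loss = mse(output_v, win_probability) + categorical_cross_entropy(output_policy, policy)
loss.backward()   # one pass, same accumulated gradients as the two calls above
optimizer.step()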

main.py

Lines changed: 43 additions & 1 deletion

@@ -1,4 +1,46 @@
+import re
+from collections import deque
+from copy import deepcopy
+from itertools import count, islice
+from pathlib import Path
+
+import torch
+
+from Config import Config
+from agent.Agent import Agent
+from agent.self_play import self_play
 from agent.train_agent import train_agent

+
+def main():
+    training_buffer = deque(maxlen=Config.training_buffer_len)
+    agents = deque((Agent(Config.n_players) for _ in range(Config.n_players)), maxlen=Config.n_players)
+    if Config.pretrain:
+        for agent, checkpoint_index in zip(islice(reversed(agents), 1, None), sorted((int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), reverse=True)):
+            agent.load_state_dict(torch.load(Config.model_path.joinpath(f'{checkpoint_index}.pth')))
+        agents[-1].load_state_dict(torch.load(Config.model_path.joinpath(f"{max(int(path.name.split('.')[0]) for path in Config.model_path.iterdir())}.pth")))
+        training_buffer += list(map(eval, map(Path.read_text, sorted(Config.data_path.iterdir(), key=lambda path: int(path.name), reverse=True)[:Config.training_buffer_len])))
+        train_agent(agents[-1], training_buffer)
+    scores = deque(maxlen=Config.max_results_held)
+    for _ in (count() if Config.n_games is None else range(Config.n_games)):
+        buffer, winner = self_play(agents)
+        start_index = max((*tuple(int(path.name) for path in Config.data_path.iterdir()), -1)) + 1
+        for start_index, sample in enumerate(buffer, start_index + 1):
+            Config.data_path.joinpath(str(start_index)).write_text(str((list(sample[0]), list(sample[1]), sample[2])))
+        scores.append(agents[-1] is winner)
+        if (len(scores) < Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * Config.min_games_to_replace_agents / len(agents)) or (len(scores) >= Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * len(scores) / len(agents)):
+            torch.save(agents[-1].state_dict(), Config.model_path.joinpath(str(max(map(int, (*re.findall(r'\d+', ''.join(map(str, Config.model_path.iterdir()))), -1))) + 1) + ".pth"))
+            agents.append(Agent(Config.n_players))
+            agents[-1].load_state_dict(deepcopy(agents[-1].state_dict()))
+            agents[-1].training = True
+            scores = deque(maxlen=Config.max_results_held)
+        elif len(scores) >= Config.min_games_to_replace_agents:
+            print(f'{len(scores)} {sum(scores) / len(scores):.2f}')
+        else:
+            print(f'{len(scores)} {sum(scores)}/{len(scores)}')
+        training_buffer += buffer
+        train_agent(agents[-1], training_buffer)
+
+
 if __name__ == "__main__":
-    train_agent()
+    main()
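
A worked check of the replacement criterion under the Config above (n_players = 2, so len(agents) = 2; min_games_to_replace_agents = 20; minimal_relative_agent_improvement = 1.1):

# wins needed once 20 results are held:
threshold = 1.1 * 20 / 2   # 11.0 wins out of 20, i.e. a 55% win rate,
# versus the 50% expected of an agent no better than its frozen opponent;
# as the score window grows toward max_results_held = 100, the bar becomes
# 1.1 * 100 / 2 = 55 wins, the same sustained 55% rate.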
