Skip to content

Commit 04c4c39

Browse files
committed
Added data saving
1 parent e8cd14d commit 04c4c39

6 files changed

Lines changed: 11 additions & 7 deletions

File tree

Config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class _ConfigPaths:
1919

2020
class Config(_ConfigPaths):
2121
# hidden_sizes = (256, 128, 64, 32)
22+
c = .1
2223
hidden_sizes = (256,)
2324
# hidden_sizes = tuple()
2425
learning_rate = 1e-3

agent/Agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(
2323
first_size = self._get_size(n_players)
2424
sizes = first_size, *hidden_sizes
2525
self.layers = nn.ModuleList(starmap(nn.Linear, pairwise(sizes)))
26-
self.trained = False
26+
self.trained = True
2727
self.fc_v = nn.Linear(sizes[-1], 1)
2828
self.fc_p = nn.Linear(sizes[-1], n_moves)
2929
self._n_moves = n_moves

agent/self_play.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import deque
2-
from itertools import cycle
2+
from itertools import cycle, count
33

44
import numpy as np
55
from tqdm import tqdm
@@ -16,8 +16,9 @@ def self_play(agents: deque[Agent]) -> tuple[list[tuple[np.array, np.array, int]
1616
id_to_agent = dict((player.id, agent) for agent, player in zip(agents, game.players))
1717
for agent in agents:
1818
agent.eval()
19-
for agent in tqdm(cycle(agents)):
20-
pi, action = policy(game, agent, 1, Config.n_simulations)
19+
for _ in tqdm(count()):
20+
agent = id_to_agent[game.current_player.id]
21+
pi, action = policy(game, agent, Config.c, Config.n_simulations)
2122
action_index = game.all_moves.index(action)
2223
onehot_encoded_action = np.zeros(Config.n_actions)
2324
onehot_encoded_action[action_index] = 1

main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from collections import deque, defaultdict
2+
from collections import deque
33
from copy import deepcopy
44
from itertools import count
55

@@ -17,9 +17,10 @@ def main():
1717
scores = deque(maxlen=Config.max_results_held)
1818
for _ in (count() if Config.n_games is None else range(Config.n_games)):
1919
buffer, winner = self_play(agents)
20+
Config.data_path.joinpath(str(max((*tuple(map(int, map(str, Config.data_path.iterdir()))), -1)) + 1)).write_text(str((list(buffer[0][0]), list(buffer[0][1]), buffer[0][2])))
2021
scores.append(agents[-1] is winner)
2122
if len(scores) >= Config.min_games_to_replace_agents and sum(scores) > Config.minimal_relative_agent_improvement * len(scores) / len(agents):
22-
torch.save(agents[-1].state_dict(), Config.model_path.joinpath(str(max(map(int, (*re.findall(r'\d+', ''.join(Config.model_path.iterdir())), -1))) + 1) + ".pth"))
23+
torch.save(agents[-1].state_dict(), Config.model_path.joinpath(str(max(map(int, (*re.findall(r'\d+', ''.join(map(str, Config.model_path.iterdir()))), -1))) + 1) + ".pth"))
2324
agents.append(Agent(Config.n_players).load_state_dict(deepcopy(agents[-1].state_dict())))
2425
agents[-1].training = True
2526
scores = deque(maxlen=Config.max_results_held)

src/Game.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def copy(self) -> Self:
130130
),
131131
),
132132
n_players=self.n_players,
133+
_last_turn=self._last_turn,
133134
)
134135
game.current_player = game.players[0]
135136
for player in game.players:

src/moves/Move.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
class Move(ABC):
1111
@abstractmethod
1212
def perform(self, game: "Game") -> "Game":
13-
# game = game.copy()
13+
game = game.copy()
1414
game.is_blocked[game.current_player] = False
1515
return game
1616

0 commit comments

Comments (0)