Skip to content

Commit 7654d78

Browse files
authored
Merge pull request #1 from Tesla2000/feature/optimization
Feature/optimization
2 parents 4b690c4 + 5ee8d6f commit 7654d78

31 files changed

Lines changed: 487 additions & 170 deletions

.idea/Splendor.iml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/misc.xml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/vcs.xml

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import random
2+
3+
import numpy as np
4+
import torch
5+
6+
random.seed(42)
7+
np.random.seed(42)
8+
torch.random.manual_seed(42)
9+
10+
11+
class Config:
12+
min_n_points_to_finish = 15
13+
n_simulations = 100
14+
n_games = 1
15+
n_players = 2

agent/Agent.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from itertools import pairwise, starmap
2+
3+
import numpy as np
4+
from torch import nn, Tensor
5+
6+
7+
class Agent(nn.Module):
8+
_input_size_dictionary = {
9+
2: 215,
10+
}
11+
12+
def __init__(
13+
self,
14+
n_players: int,
15+
hidden_sizes: tuple = (256, 128, 64, 32),
16+
n_moves: int = 46,
17+
):
18+
super().__init__()
19+
self.relu = nn.ReLU()
20+
self.tanh = nn.Tanh()
21+
self.softmax = nn.Softmax(dim=1)
22+
first_size = self._get_size(n_players)
23+
sizes = first_size, *hidden_sizes
24+
self.layers = tuple(starmap(nn.Linear, pairwise(sizes)))
25+
for index, layer in enumerate(self.layers):
26+
setattr(self, f"layer_{index}", layer)
27+
self.fc_v = nn.Linear(hidden_sizes[-1], 1)
28+
self.fc_p = nn.Linear(hidden_sizes[-1], n_moves)
29+
self._n_moves = n_moves
30+
self._trained = False
31+
32+
def _get_size(self, n_players: int) -> int:
33+
return self._input_size_dictionary[n_players]
34+
35+
def forward(self, state: Tensor):
36+
if not self.training and not self._trained:
37+
return self.softmax(Tensor(np.random.random((1, self._n_moves)))), Tensor(
38+
np.random.uniform(-1, 1, (1, 1))
39+
)
40+
self._trained = True
41+
for layer in self.layers:
42+
state = layer(state)
43+
state = self.relu(state)
44+
return self.softmax(self.fc_p(state)), self.tanh(self.fc_v(state))

agent/__init__.py

Whitespace-only changes.

agent/train_agent.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from collections import defaultdict
2+
from dataclasses import astuple
3+
from math import sqrt
4+
5+
import numpy as np
6+
from torch import nn, Tensor
7+
from tqdm import tqdm
8+
9+
from Config import Config
10+
from src.Game import Game
11+
from .Agent import Agent
12+
13+
14+
def train_agent():
15+
agent = Agent(Config.n_players)
16+
agent.eval()
17+
examples = []
18+
examples_per_game = []
19+
for i in range(Config.n_games):
20+
game = Game(n_players=Config.n_players)
21+
while True:
22+
pi, action = policy(game, agent, 1, Config.n_simulations)
23+
examples_per_game.append((game, pi, 0))
24+
game = game.perform(action)
25+
print(len(game.players[1].cards), game.players[1].points)
26+
if game.is_terminal():
27+
for example in examples_per_game:
28+
example[2] = game.get_state()
29+
break
30+
examples += examples_per_game
31+
break
32+
return examples
33+
34+
35+
def search(
36+
game: Game,
37+
agent: nn.Module,
38+
c: float,
39+
N: defaultdict,
40+
visited: set,
41+
P: defaultdict,
42+
Q: defaultdict,
43+
):
44+
state = game.get_state()
45+
if game.is_terminal():
46+
return game.get_results()[game.current_player]
47+
if state not in visited:
48+
visited.add(state)
49+
move_scores, v = agent(Tensor([state]))
50+
tuple(
51+
P[state].__setitem__(move, move_scores[0, index])
52+
for index, move in enumerate(game.all_moves)
53+
)
54+
return -v
55+
56+
action = max(
57+
game.get_possible_actions(),
58+
key=lambda action: Q[state].get(action, 1)
59+
+ c * P[state][action] * sqrt(sum(N[state].values())) / (1 + N[state][action]),
60+
)
61+
62+
next_game_state = game.perform(action)
63+
v = search(next_game_state, agent, c, N, visited, P, Q)
64+
65+
Q[state][action] = (N[state][action] * Q[state].get(action, 1) + v) / (
66+
N[state][action] + 1
67+
)
68+
N[state][action] += 1
69+
return -v
70+
71+
72+
def policy(
73+
game: Game,
74+
agent: nn.Module,
75+
c: float,
76+
n_simulations: int,
77+
):
78+
N = defaultdict(lambda: defaultdict(int))
79+
visited = set()
80+
P = defaultdict(dict)
81+
Q = defaultdict(dict)
82+
initial_state = game.get_state()
83+
all_moves = game.get_possible_actions()
84+
for _ in tqdm(range(n_simulations)):
85+
search(game, agent, c, N, visited, P, Q)
86+
pi = [N[initial_state][a] for a in all_moves]
87+
return pi, all_moves[np.argmax(pi)]

hashabledict.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
class hashabledict(dict):
2+
def __hash__(self):
3+
return tuple(self.items()).__hash__()

main.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
import random
1+
from agent.train_agent import train_agent
22

3-
from src.Game import Game
4-
5-
if __name__ == '__main__':
6-
game = Game()
7-
while not game.is_terminal():
8-
game.perform(random.choice(game.get_possible_actions()))
3+
if __name__ == "__main__":
4+
train_agent()

src/Game.py

Lines changed: 75 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,28 @@
1-
from dataclasses import astuple, fields, asdict, dataclass, field
1+
from dataclasses import fields, dataclass, field
22
from itertools import combinations, starmap, product
3-
from typing import Self, Iterable, Any
4-
from dacite import from_dict
3+
from typing import Self, Type
54

5+
from Config import Config
6+
from .StateExtractor import StateExtractor
7+
from .entities.AllResources import AllResources
68
from .entities.BasicResources import BasicResources
79
from .entities.Board import Board
8-
from .entities.Card import empty_card
910
from .entities.Player import Player
1011
from .entities.Tier import Tier
11-
from .moves import Move, GrabThreeResource, GrabTwoResource, BuildBoard, BuildReserve, ReserveVisible, ReserveTop, \
12-
NullMove
12+
from .entities.extended_lists.Aristocrats import Aristocrats
13+
from .entities.extended_lists.PlayerAristocrats import PlayerAristocrats
14+
from .entities.extended_lists.PlayerCards import PlayerCards
15+
from .entities.extended_lists.PlayerReserve import PlayerReserve
16+
from .moves import (
17+
Move,
18+
GrabThreeResource,
19+
GrabTwoResource,
20+
BuildBoard,
21+
BuildReserve,
22+
ReserveVisible,
23+
ReserveTop,
24+
NullMove,
25+
)
1326

1427

1528
@dataclass(slots=True)
@@ -22,24 +35,24 @@ class Game:
2235
_turn_counter: int = 0
2336
_performed_the_last_move: dict = None
2437
_last_turn: bool = False
38+
_state_extractor: Type[StateExtractor] = StateExtractor
2539

2640
def __post_init__(self):
27-
if not self.board or not self.players:
41+
if not self.players:
2842
self.players = tuple(Player() for _ in range(self.n_players))
43+
if not self.board:
2944
self.board = Board(self.n_players)
45+
if not self._performed_the_last_move:
3046
self._performed_the_last_move = dict(
3147
(player, False) for player in self.players
3248
)
33-
self.is_blocked = dict(
34-
(player, False) for player in self.players
35-
)
36-
self._last_turn = False
49+
self.is_blocked = dict((player, False) for player in self.players)
3750
self.current_player = self.players[0]
3851

3952
def perform(self, action: Move) -> Self:
40-
action.perform(self)
41-
self.next_turn()
42-
return self
53+
new_state = action.perform(self)
54+
new_state.next_turn()
55+
return new_state
4356

4457
def next_turn(self) -> None:
4558
self.players = (*self.players[1:], self.players[0])
@@ -48,73 +61,76 @@ def next_turn(self) -> None:
4861
self.current_player.aristocrats.append(
4962
self.board.aristocrats.pop(index)
5063
)
51-
if self.current_player.points >= 15 or self._last_turn:
64+
if self.current_player.points >= Config.min_n_points_to_finish or self._last_turn:
5265
self._last_turn = True
5366
self._performed_the_last_move[self.current_player] = self._last_turn
5467
self.current_player = self.players[0]
55-
self._turn_counter += 1
5668

5769
def is_terminal(self) -> bool:
5870
return all(self._performed_the_last_move.values()) or (
5971
not self.get_possible_actions()
6072
)
6173

62-
def get_results(self) -> dict[Player, bool]:
74+
def get_results(self) -> dict[Player, int]:
6375
results = {}
6476
for player in self.players:
65-
if not all(self._performed_the_last_move.values()):
66-
results[player] = player == max(self.players, key=lambda p: (p.points, -len(p.cards)))
67-
else:
68-
print("Finished game")
77+
results[player] = (
78+
1
79+
if player
80+
== max(self.players, key=lambda p: (p.points, -len(p.cards)))
81+
else -1
82+
)
6983
return results
7084

7185
def get_state(self) -> tuple:
72-
tiers = self.board.tiers
73-
self.board.tiers = list(Tier([], tier.visible) for tier in tiers)
74-
state = self._flatter_recursively(astuple(self.board))
75-
self.board.tiers = tiers
76-
for player in self.players:
77-
state += astuple(player.resources, tuple_factory=list)
78-
state += astuple(player.production, tuple_factory=list)
79-
if player != self.current_player:
80-
state.append(sum(card != empty_card for card in player.reserve))
81-
else:
82-
state += self._flatter_recursively(map(astuple, self.current_player.reserve))
83-
state.append(player.points)
84-
return tuple(state)
86+
return self._state_extractor.get_state(self)
8587

8688
def copy(self) -> Self:
87-
game = from_dict(Game, asdict(self))
89+
game = Game(
90+
players=tuple(
91+
Player(
92+
resources=AllResources(
93+
(resources := player.resources).red,
94+
resources.green,
95+
resources.blue,
96+
resources.black,
97+
resources.white,
98+
resources.gold,
99+
),
100+
cards=PlayerCards(player.cards),
101+
reserve=PlayerReserve(player.reserve),
102+
aristocrats=PlayerAristocrats(player.aristocrats),
103+
)
104+
for player in self.players
105+
),
106+
board=Board(
107+
n_players=(board := self.board).n_players,
108+
tiers=list(Tier(list(tier.hidden), list(tier.visible)) for tier in board.tiers),
109+
aristocrats=Aristocrats(board.aristocrats),
110+
resources=AllResources(
111+
board.resources.red,
112+
board.resources.green,
113+
board.resources.blue,
114+
board.resources.black,
115+
board.resources.white,
116+
board.resources.gold,
117+
),
118+
),
119+
n_players=self.n_players,
120+
)
88121
game.current_player = game.players[0]
122+
for player in game.players:
123+
game.is_blocked[player] = next(
124+
value for key, value in self.is_blocked.items() if key == player
125+
)
126+
game._performed_the_last_move[player] = next(
127+
value for key, value in self._performed_the_last_move.items() if key == player
128+
)
89129
return game
90130

91131
def get_possible_actions(self) -> list[Move]:
92132
return list(move for move in self.all_moves if move.is_valid(self))
93133

94-
def _flatter_recursively(
95-
self, iterable: Iterable, output: list = None, expected_length: int = None
96-
) -> list:
97-
if output is None:
98-
if expected_length:
99-
output = expected_length * [None]
100-
if not expected_length:
101-
return list(self._get_flatten_elements(iterable))
102-
index = 0
103-
for index, item in enumerate(self._get_flatten_elements(iterable)):
104-
if expected_length is None:
105-
output[index] = item
106-
if index != expected_length - 1:
107-
raise ValueError
108-
return output
109-
110-
def _get_flatten_elements(self, iterable: Iterable) -> Any:
111-
for element in iterable:
112-
if isinstance(element, Iterable):
113-
for inner_element in self._get_flatten_elements(element):
114-
yield inner_element
115-
else:
116-
yield element
117-
118134
combos = combinations([{field.name: 1} for field in fields(BasicResources)], 3)
119135
all_moves = list(
120136
GrabThreeResource(BasicResources(**res_1, **res_2, **res_3))
@@ -130,4 +146,3 @@ def _get_flatten_elements(self, iterable: Iterable) -> Any:
130146
all_moves += list(starmap(ReserveVisible, product(range(3), range(4))))
131147
all_moves += list(map(ReserveTop, range(3)))
132148
all_moves.append(NullMove())
133-
all_moves = tuple(all_moves)

0 commit comments

Comments
 (0)