Skip to content

Commit 2bb8a27

Browse files
committed
Moving on to optimize copy
1 parent 5471e15 commit 2bb8a27

6 files changed

Lines changed: 81 additions & 57 deletions

File tree

Config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@
99

1010

1111
class Config:
12-
n_simulations = 1000
12+
n_simulations = 100
1313
n_games = 100
1414
n_players = 2

agent/train_agent.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,17 @@ def train_agent():
1717
examples = []
1818
examples_per_game = []
1919
for i in range(Config.n_games):
20-
N = defaultdict(lambda: defaultdict(int))
2120
game = Game(n_players=Config.n_players)
22-
visited = set()
23-
P = defaultdict(dict)
24-
Q = defaultdict(dict)
2521
while True:
26-
pi, action = policy(game, agent, 1, Config.n_simulations, N, visited, P, Q)
22+
pi, action = policy(game, agent, 1, Config.n_simulations)
2723
examples_per_game.append((game, pi, 0))
2824
game = game.perform(action)
2925
if game.is_terminal():
3026
for example in examples_per_game:
3127
example[2] = game.get_state()
3228
break
3329
examples += examples_per_game
30+
break
3431
return examples
3532

3633

@@ -52,8 +49,10 @@ def search(
5249
if state not in visited:
5350
visited.add(state)
5451
move_scores, v = agent(Tensor([state]))
55-
for index, move in enumerate(game.all_moves):
56-
P[state][move] = move_scores[0, index]
52+
tuple(
53+
P[state].__setitem__(move, move_scores[0, index])
54+
for index, move in enumerate(game.all_moves)
55+
)
5756
return -v
5857

5958
action = max(
@@ -77,11 +76,11 @@ def policy(
7776
agent: nn.Module,
7877
c: float,
7978
n_simulations: int,
80-
N: defaultdict,
81-
visited: set,
82-
P: defaultdict,
83-
Q: defaultdict,
8479
):
80+
N = defaultdict(lambda: defaultdict(int))
81+
visited = set()
82+
P = defaultdict(dict)
83+
Q = defaultdict(dict)
8584
initial_state = game.get_state()
8685
all_moves = game.get_possible_actions()
8786
for _ in tqdm(range(n_simulations)):

src/StateExtractor.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,42 @@
11
from dataclasses import astuple
2+
from itertools import chain
23
from typing import Iterable, Any, TYPE_CHECKING
34

4-
from .entities.Card import empty_card
5-
from .entities.Tier import Tier
6-
75
if TYPE_CHECKING:
86
from .Game import Game
97

108

119
class StateExtractor:
1210
@classmethod
1311
def get_state(cls, game: "Game") -> tuple:
14-
tiers = game.board.tiers
15-
game.board.tiers = list(Tier([], tier.visible) for tier in tiers)
16-
state = cls._flatter_recursively(astuple(game.board))
17-
game.board.tiers = tiers
18-
for player in game.players:
19-
state += astuple(player.resources, tuple_factory=list)
20-
state += astuple(player.production, tuple_factory=list)
21-
if player is not game.current_player:
22-
state.append(sum(card != empty_card for card in player.reserve))
23-
else:
24-
state += cls._flatter_recursively(
25-
map(astuple, game.current_player.reserve)
12+
return tuple(
13+
chain.from_iterable(
14+
(
15+
chain.from_iterable(
16+
(*astuple(card.cost), *astuple(card.production), card.points)
17+
for tier in game.board.tiers
18+
for card in tier.visible
19+
),
20+
chain.from_iterable(
21+
(*astuple(card.cost), *astuple(card.production), card.points)
22+
for card in game.current_player.reserve
23+
),
24+
chain.from_iterable(
25+
astuple(aristocrat.cost)
26+
for aristocrat in game.board.aristocrats
27+
),
28+
chain.from_iterable(
29+
(
30+
*astuple(player.resources),
31+
*astuple(player.production),
32+
player.points,
33+
)
34+
for player in game.players
35+
),
36+
(len(player.reserve) for player in game.players[1:]),
2637
)
27-
state.append(player.points)
28-
return tuple(state)
38+
)
39+
)
2940

3041
@classmethod
3142
def _flatter_recursively(

src/entities/AllResources.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass, asdict, fields, astuple
1+
from dataclasses import dataclass
22
from typing import Self
33

44
from .BasicResources import BasicResources
@@ -11,23 +11,35 @@ class AllResources(BasicResources):
1111
def __sub__(self, other: BasicResources) -> Self:
1212
if not isinstance(other, BasicResources):
1313
raise ValueError(f"Other element must be resource is {other.__class__}")
14-
self_dict = asdict(self)
15-
other_dict = asdict(other)
16-
resources = AllResources(
17-
**dict(
18-
(key, value - other_dict.get(key, 0))
19-
for key, value in self_dict.items()
20-
)
21-
)
22-
resources = AllResources(
23-
*tuple(max(0, resource) for resource in astuple(resources)[:-1]),
24-
resources.gold
14+
return AllResources(
15+
max(0, self.red - other.red),
16+
max(0, self.green - other.green),
17+
max(0, self.blue - other.blue),
18+
max(0, self.black - other.black),
19+
max(0, self.white - other.white),
20+
self.gold
2521
+ sum(
26-
min(0, getattr(resources, field.name))
27-
for field in fields(BasicResources)
22+
(
23+
min(0, self.red - other.red),
24+
min(0, self.green - other.green),
25+
min(0, self.blue - other.blue),
26+
min(0, self.black - other.black),
27+
min(0, self.white - other.white),
28+
)
2829
),
2930
)
30-
return resources
31+
32+
def __add__(self, other: Self) -> Self:
33+
if not isinstance(other, BasicResources):
34+
raise ValueError(f"Other element must be resource is {other.__class__}")
35+
return AllResources(
36+
self.red + other.red,
37+
self.green + other.green,
38+
self.blue + other.blue,
39+
self.black + other.black,
40+
self.white + other.white,
41+
self.gold + getattr(other, "gold", 0),
42+
)
3143

3244
def __rsub__(self, other: BasicResources) -> Self:
3345
return self.__sub__(other)

src/entities/BasicResources.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@ class BasicResources:
1313
def __add__(self, other: Self) -> Self:
1414
if not isinstance(other, BasicResources):
1515
raise ValueError(f"Other element must be resource is {other.__class__}")
16-
self_dict = asdict(self)
17-
other_dict = asdict(other)
18-
return type(self)(
19-
**dict(
20-
(key, value + other_dict.get(key, 0))
21-
for key, value in self_dict.items()
22-
)
16+
return BasicResources(
17+
self.red + other.red,
18+
self.green + other.green,
19+
self.blue + other.blue,
20+
self.black + other.black,
21+
self.white + other.white,
2322
)

src/entities/extended_lists/PlayerCards.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1-
import operator
2-
from functools import reduce
3-
4-
from ..BasicResources import BasicResources
51
from .hashablelist import hashablelist
2+
from ..BasicResources import BasicResources
63

74

85
class PlayerCards(hashablelist):
96
@property
107
def production(self) -> BasicResources:
11-
return reduce(
12-
operator.add, (card.production for card in self), BasicResources()
8+
if not self:
9+
return BasicResources()
10+
return BasicResources(
11+
sum(card.production.red for card in self),
12+
sum(card.production.green for card in self),
13+
sum(card.production.blue for card in self),
14+
sum(card.production.black for card in self),
15+
sum(card.production.white for card in self),
1316
)
1417

1518
@property

0 commit comments

Comments
 (0)