Skip to content

Commit b655359

Browse files
committed
Fixing mistakes one by one
1 parent 2bb8a27 commit b655359

7 files changed

Lines changed: 84 additions & 45 deletions

File tree

agent/train_agent.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77

88
from Config import Config
99
from src.Game import Game
10-
from src.entities.Player import Player
1110
from .Agent import Agent
1211

1312

@@ -39,13 +38,10 @@ def search(
3938
visited: set,
4039
P: defaultdict,
4140
Q: defaultdict,
42-
evaluated_player: Player = None,
4341
):
44-
if evaluated_player is None:
45-
evaluated_player = game.current_player
4642
state = game.get_state()
4743
if game.is_terminal():
48-
return game.get_results()[evaluated_player]
44+
return game.get_results()[game.current_player]
4945
if state not in visited:
5046
visited.add(state)
5147
move_scores, v = agent(Tensor([state]))
@@ -62,7 +58,7 @@ def search(
6258
)
6359

6460
next_game_state = game.perform(action)
65-
v = search(next_game_state, agent, c, N, visited, P, Q, evaluated_player)
61+
v = search(next_game_state, agent, c, N, visited, P, Q)
6662

6763
Q[state][action] = (N[state][action] * Q[state].get(action, 1) + v) / (
6864
N[state][action] + 1

src/Game.py

Lines changed: 42 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,17 @@
1-
from dataclasses import fields, asdict, dataclass, field
1+
from dataclasses import fields, dataclass, field
22
from itertools import combinations, starmap, product
33
from typing import Self, Type
44

5-
from dacite import from_dict
6-
7-
from hashabledict import hashabledict
85
from .StateExtractor import StateExtractor
6+
from .entities.AllResources import AllResources
97
from .entities.BasicResources import BasicResources
108
from .entities.Board import Board
119
from .entities.Player import Player
10+
from .entities.Tier import Tier
11+
from .entities.extended_lists.Aristocrats import Aristocrats
12+
from .entities.extended_lists.PlayerAristocrats import PlayerAristocrats
13+
from .entities.extended_lists.PlayerCards import PlayerCards
14+
from .entities.extended_lists.PlayerReserve import PlayerReserve
1215
from .moves import (
1316
Move,
1417
GrabThreeResource,
@@ -34,14 +37,15 @@ class Game:
3437
_state_extractor: Type[StateExtractor] = StateExtractor
3538

3639
def __post_init__(self):
37-
if not self.board or not self.players:
40+
if not self.players:
3841
self.players = tuple(Player() for _ in range(self.n_players))
42+
if not self.board:
3943
self.board = Board(self.n_players)
44+
if not self._performed_the_last_move:
4045
self._performed_the_last_move = dict(
4146
(player, False) for player in self.players
4247
)
4348
self.is_blocked = dict((player, False) for player in self.players)
44-
self._last_turn = False
4549
self.current_player = self.players[0]
4650

4751
def perform(self, action: Move) -> Self:
@@ -60,7 +64,6 @@ def next_turn(self) -> None:
6064
self._last_turn = True
6165
self._performed_the_last_move[self.current_player] = self._last_turn
6266
self.current_player = self.players[0]
63-
self._turn_counter += 1
6467

6568
def is_terminal(self) -> bool:
6669
return all(self._performed_the_last_move.values()) or (
@@ -85,8 +88,38 @@ def get_state(self) -> tuple:
8588
return self._state_extractor.get_state(self)
8689

8790
def copy(self) -> Self:
88-
dict_repr = asdict(self, dict_factory=hashabledict)
89-
game = from_dict(Game, dict_repr)
91+
game = Game(
92+
players=tuple(
93+
Player(
94+
resources=AllResources(
95+
(resources := player.resources).red,
96+
resources.green,
97+
resources.blue,
98+
resources.black,
99+
resources.white,
100+
resources.gold,
101+
),
102+
cards=PlayerCards(tuple(player.cards)),
103+
reserve=PlayerReserve(tuple(player.reserve)),
104+
aristocrats=PlayerAristocrats(tuple(player.aristocrats)),
105+
)
106+
for player in self.players
107+
),
108+
board=Board(
109+
n_players=(board := self.board).n_players,
110+
tiers=list(Tier(tier.hidden, tier.visible) for tier in board.tiers),
111+
aristocrats=Aristocrats(board.aristocrats),
112+
resources=AllResources(
113+
board.resources.red,
114+
board.resources.green,
115+
board.resources.blue,
116+
board.resources.black,
117+
board.resources.white,
118+
board.resources.gold,
119+
),
120+
),
121+
n_players=self.n_players,
122+
)
90123
game.current_player = game.players[0]
91124
for player in game.players:
92125
game.is_blocked[player] = next(

src/entities/AllResources.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,3 +46,6 @@ def __rsub__(self, other: BasicResources) -> Self:
4646

4747
def lacks(self) -> bool:
4848
return self.gold < 0
49+
50+
if __name__ == '__main__':
51+
(AllResources(red=0, green=0, blue=0, black=4, white=4, gold=5) - BasicResources(red=1, green=1, blue=0, black=0, white=1))

src/moves/GrabResource.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,13 @@
11
from abc import ABC
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, astuple
3+
from typing import TYPE_CHECKING
34

45
from src.entities.BasicResources import BasicResources
56
from .Move import Move
7+
from ..entities.AllResources import AllResources
8+
9+
if TYPE_CHECKING:
10+
from src.Game import Game
611

712

813
@dataclass(slots=True, frozen=True)
@@ -11,3 +16,17 @@ class GrabResource(Move, ABC):
1116

1217
def __repr__(self):
1318
return self.resources.__repr__()
19+
20+
def is_valid(self, game: "Game") -> bool:
21+
if (
22+
sum(astuple(game.current_player.resources)) + sum(astuple(self.resources))
23+
> 10
24+
):
25+
return False
26+
return not (AllResources(
27+
game.board.resources.red,
28+
game.board.resources.green,
29+
game.board.resources.blue,
30+
game.board.resources.black,
31+
game.board.resources.white,
32+
) - self.resources).lacks()

src/moves/GrabThreeResource.py

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,11 @@
11
from dataclasses import astuple
2-
32
from typing import TYPE_CHECKING
43

4+
from .GrabResource import GrabResource
55
from .Move import Move
66

77
if TYPE_CHECKING:
88
from src.Game import Game
9-
from .GrabResource import GrabResource
109

1110

1211
class GrabThreeResource(GrabResource):
@@ -17,11 +16,3 @@ def perform(self, game: "Game") -> "Game":
1716
if sum(astuple(game.current_player.resources)) > 10:
1817
raise ValueError
1918
return game
20-
21-
def is_valid(self, game: "Game") -> bool:
22-
if (
23-
sum(astuple(game.current_player.resources)) + sum(astuple(self.resources))
24-
> 10
25-
):
26-
return False
27-
return not (game.board.resources - self.resources).lacks()

src/moves/GrabTwoResource.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,10 +20,7 @@ def perform(self, game: "Game") -> "Game":
2020
return game
2121

2222
def is_valid(self, game: "Game") -> bool:
23-
tuple_resources = astuple(self.resources)
24-
if sum(astuple(game.current_player.resources)) + sum(tuple_resources) > 10:
25-
return False
26-
resource = next(compress(asdict(self.resources).keys(), tuple_resources))
23+
resource = next(compress(asdict(self.resources).keys(), astuple(self.resources)))
2724
if getattr(game.board.resources, resource) < 4:
2825
return False
29-
return not (game.board.resources - self.resources).lacks()
26+
return super().is_valid(game)

src/moves/Reserve.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
11
from abc import ABC
2-
from dataclasses import dataclass, fields
3-
4-
from src.entities.Card import Card
2+
from dataclasses import dataclass
53
from typing import TYPE_CHECKING
64

5+
from src.entities.Card import Card
76
from ..entities.AllResources import AllResources
8-
from ..entities.BasicResources import BasicResources
97

108
if TYPE_CHECKING:
119
from src.Game import Game
@@ -21,16 +19,18 @@ def reserve_card(self, game: "Game", card: Card):
2119
current_player.reserve.append(card)
2220
if game.board.resources.gold:
2321
game.board.resources = AllResources(
24-
**dict(
25-
(field.name, getattr(game.board.resources, field.name))
26-
for field in fields(BasicResources)
27-
),
28-
gold=game.board.resources.gold - 1
22+
game.board.resources.red,
23+
game.board.resources.green,
24+
game.board.resources.blue,
25+
game.board.resources.black,
26+
game.board.resources.white,
27+
game.board.resources.gold - 1
2928
)
3029
current_player.resources = AllResources(
31-
**dict(
32-
(field.name, getattr(game.board.resources, field.name))
33-
for field in fields(BasicResources)
34-
),
35-
gold=game.board.resources.gold + 1
30+
current_player.resources.red,
31+
current_player.resources.green,
32+
current_player.resources.blue,
33+
current_player.resources.black,
34+
current_player.resources.white,
35+
current_player.resources.gold + 1
3636
)

0 commit comments

Comments
 (0)