Skip to content

Commit 8ff5b7a

Browse files
committed
Added missing state values
1 parent 6340e62 commit 8ff5b7a

6 files changed

Lines changed: 35 additions & 15 deletions

File tree

Config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ class _ConfigAgent:
2727
c = 0.2
2828
learning_rate = 1e-5
2929
debug = False
30-
pretrain = True
30+
# pretrain = True
31+
pretrain = False
3132

3233

3334
class Config(_ConfigPaths, _ConfigAgent):

agent/Agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class Agent(nn.Module):
1010
_input_size_dictionary = {
11-
2: 205,
11+
2: 211,
1212
}
1313

1414
def __init__(

agent/policy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ def policy(
1313
c: float,
1414
n_simulations: int,
1515
):
16-
N = defaultdict(lambda: defaultdict(int))
16+
N = defaultdict(list)
1717
visited = set()
18-
P = defaultdict(dict)
19-
Q = defaultdict(dict)
18+
P = defaultdict(list)
19+
Q = defaultdict(list)
2020
initial_state = game.get_state()
2121
all_moves = game.get_possible_actions()
2222
for _ in range(n_simulations):

agent/search.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ def search(
1010
game: Game,
1111
agent: nn.Module,
1212
c: float,
13-
N: defaultdict,
13+
N: defaultdict[list[int]],
1414
visited: set,
15-
P: defaultdict,
16-
Q: defaultdict,
15+
P: defaultdict[list],
16+
Q: defaultdict[list],
1717
):
1818
if game.is_terminal():
1919
return -game.get_results()[game.current_player.id]
@@ -26,13 +26,26 @@ def search(
2626
for index, move in enumerate(game.all_moves)
2727
)
2828
return -v
29-
30-
action = max(
31-
game.get_possible_actions(),
32-
key=lambda action: Q[state].get(action, 1)
33-
+ c * P[state][action] * sqrt(sum(N[state].values())) / (1 + N[state][action]),
34-
)
35-
29+
q_state = Q[state]
30+
p_state = P[state]
31+
n_state = N[state]
32+
sqrt_value = sqrt(sum(n_state.values()))
33+
def _get_action(game: Game):
34+
return max(
35+
game.get_possible_actions(),
36+
key=lambda action: q_state.get(action, 1) + c * p_state[action] * sqrt_value / (1 + n_state[action]),
37+
)
38+
# def _get_action(game: Game):
39+
# actions = sorted(
40+
# game.all_moves,
41+
# key=lambda action: q_state.get(action, 1)
42+
# + c * p_state[action] * sqrt_value / (1 + n_state[action]),
43+
# reverse=True,
44+
# )
45+
# for action in actions:
46+
# if action.is_valid(game):
47+
# return action
48+
action = _get_action(game)
3649
next_game_state = game.perform(action)
3750
v = search(next_game_state, agent, c, N, visited, P, Q)
3851

src/Game.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ def get_possible_actions(self) -> tuple[Move, ...]:
155155
(self.null_move,) if self.null_move.is_valid(self) else tuple()
156156
)
157157

158+
def get_possible_action_indexes(self) -> tuple[int, ...]:
159+
return tuple(index for index, move in enumerate(self.all_moves) if move.is_valid(self)) or (
160+
(self.null_move,) if self.null_move.is_valid(self) else tuple()
161+
)
162+
158163
combos = combinations([{field.name: 1} for field in fields(BasicResources)], 3)
159164
all_moves = list(
160165
GrabThreeResource(BasicResources(**res_1, **res_2, **res_3))

src/StateExtractor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def get_state(cls, game: "Game") -> tuple:
4848
tuple(iter(aristocrat.cost))
4949
for aristocrat in game.board.aristocrats
5050
),
51+
iter(game.board.resources),
5152
chain.from_iterable(
5253
(
5354
*tuple(iter(player.resources)),

0 commit comments

Comments
 (0)