
Commit daf75ee

modifying options to find a mistake
1 parent bcd46e1 commit daf75ee

6 files changed

Lines changed: 12 additions & 11 deletions


Config.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ class Config(_ConfigPaths, _ConfigAgent):
     train_batch_size = 64
     training_buffer_len = 1000
     min_n_points_to_finish = 15
-    n_simulations = 100
+    n_simulations = 250
     n_games = None
     n_players = 2
     n_actions = 46

agent/RLDataset.py

Lines changed: 2 additions & 1 deletion
@@ -12,4 +12,5 @@ def __len__(self):
         return len(self.examples)

     def __getitem__(self, index) -> tuple[np.array, ...]:
-        return np.array(self.examples[index][0]), np.array(self.examples[index][1]), np.array([self.examples[index][2]])
+        return np.array(self.examples[index][0]), np.array(self.examples[index][1]), np.array(
+            [self.examples[index][2] * 2 - 1])
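The only behavioral change here is the `* 2 - 1` on the third tuple element. Read together with the train_agent.py diff further down, it looks like a game outcome stored in {0, 1} is being rescaled to {-1, +1} so it can be regressed directly against the raw value-head output. A minimal illustrative sketch of that mapping (the rescale_outcome helper is hypothetical, not part of the repository):

import numpy as np

def rescale_outcome(win_flag: int) -> np.ndarray:
    # 0 (loss) -> -1.0, 1 (win) -> +1.0; wrapped in an array exactly as
    # __getitem__ does for the third element of each training example
    return np.array([win_flag * 2 - 1], dtype=np.float32)

assert rescale_outcome(1)[0] == 1.0 and rescale_outcome(0)[0] == -1.0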

agent/policy.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@

 def policy(
     game: Game,
-    agents: dict[int, nn.Module],
+    agent: nn.Module,
     c: float,
     n_simulations: int,
 ):
@@ -20,6 +20,6 @@ def policy(
     initial_state = game.get_state()
     all_moves = game.get_possible_actions()
     for _ in range(n_simulations):
-        search(game.copy(), agents, c, N, visited, P, Q)
+        search(game.copy(), agent, c, N, visited, P, Q)
     pi = [N[initial_state][a] for a in all_moves]
     return pi, all_moves[np.argmax(pi)]
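After this change the caller, not policy, decides which network drives all of the simulations for the current move (see the self_play.py diff below). A hedged usage sketch of the new signature, assuming game, id_to_agent, and Config as they appear elsewhere in this commit:

# One network per call, chosen for the player to move; pi is the list of raw
# visit counts N(s, a) over game.get_possible_actions(), and the returned
# action is the most-visited one.
agent = id_to_agent[game.current_player.id]
pi, action = policy(game, agent, Config.c, Config.n_simulations)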

agent/search.py

Lines changed: 3 additions & 4 deletions
@@ -8,17 +8,16 @@

 def search(
     game: Game,
-    agents: dict[int, nn.Module],
+    agent: nn.Module,
     c: float,
     N: defaultdict,
     visited: set,
     P: defaultdict,
     Q: defaultdict,
 ):
     if game.is_terminal():
-        return game.get_results()[game.current_player.id]
+        return -game.get_results()[game.current_player.id]
     state = game.get_state()
-    agent = agents[game.current_player.id]
     if state not in visited:
         visited.add(state)
         move_scores, v = agent(Tensor([state]))
@@ -35,7 +34,7 @@ def search(
     )

     next_game_state = game.perform(action)
-    v = search(next_game_state, agents, c, N, visited, P, Q)
+    v = search(next_game_state, agent, c, N, visited, P, Q)

     Q[state][action] = (N[state][action] * Q[state].get(action, 1) + v) / (
         N[state][action] + 1
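Two details in this pair of hunks are worth noting. The terminal result is now negated, which reads like the usual negamax sign convention: the value is reported from the perspective of the player one level up the recursion rather than the player to move at the terminal node (hedged, since only this line is visible in the diff). The Q update in the second hunk is an incremental mean of the values backed up through (state, action); a small sketch of that arithmetic, with hypothetical names:

def update_q(q_old: float, n_old: int, v: float) -> float:
    # Running mean of the n_old previous backed-up values and the new value v;
    # matches Q[state][action] = (N * Q + v) / (N + 1) in the hunk above,
    # where Q[state].get(action, 1) supplies an optimistic default of 1.
    return (n_old * q_old + v) / (n_old + 1)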

agent/self_play.py

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,8 @@ def self_play(agents: deque[Agent]) -> tuple[list[tuple[np.array, np.array, int]
     for agent in agents:
         agent.eval()
     for _ in tqdm(count()):
-        pi, action = policy(game, id_to_agent, Config.c, Config.n_simulations)
+        agent = id_to_agent[game.current_player.id]
+        pi, action = policy(game, agent, Config.c, Config.n_simulations)
         action_index = game.all_moves.index(action)
         onehot_encoded_action = np.zeros(Config.n_actions)
         onehot_encoded_action[action_index] = 1

agent/train_agent.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 def train_agent(agent: Agent, train_data: deque[tuple[tuple, np.array, int]]):
     agent.train()
     categorical_cross_entropy = nn.CrossEntropyLoss()
-    binary_cross_entropy = nn.BCELoss()
+    mse = nn.MSELoss()
     optimizer = optim.Adam(agent.parameters(), lr=Config.learning_rate)
     dataset = RLDataset(train_data)
     loader = DataLoader(dataset, batch_size=Config.train_batch_size)
@@ -21,7 +21,7 @@ def train_agent(agent: Agent, train_data: deque[tuple[tuple, np.array, int]]):
         state, policy, win_probability = state.float(), policy.float(), win_probability.float()
         optimizer.zero_grad()
         output_policy, output_v = agent(state)
-        bce = binary_cross_entropy((output_v + 1) / 2, win_probability)
+        bce = mse(output_v, win_probability)
         cce = categorical_cross_entropy(output_policy, policy)
         bce.backward(retain_graph=True)
         cce.backward()
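This pairs with the RLDataset.py change: instead of squashing the value head into [0, 1] with (output_v + 1) / 2 and scoring it with BCE against a 0/1 outcome, the raw head output is now regressed with MSE against targets rescaled to [-1, 1]. The local names bce and win_probability are kept from the old version even though they now hold an MSE value loss and a signed outcome. A self-contained sketch of the post-commit loss, assuming output_v is bounded in [-1, 1] (e.g. by a tanh) and summing the two terms instead of calling backward twice as the commit does:

import torch
from torch import nn

mse = nn.MSELoss()
cce = nn.CrossEntropyLoss()

def combined_loss(output_policy: torch.Tensor, output_v: torch.Tensor,
                  target_policy: torch.Tensor, target_value: torch.Tensor) -> torch.Tensor:
    # Value term: plain regression of the [-1, 1] value head against the game outcome.
    value_loss = mse(output_v, target_value)
    # Policy term: cross-entropy against the (soft) target policy from the search.
    policy_loss = cce(output_policy, target_policy)
    return value_loss + policy_loss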
