
Commit 6340e62

Merge branch 'refs/heads/feature/evaluation'

2 parents: c4c9a24 + 0d2f479

4 files changed: 27 additions & 11 deletions

File tree:
  Config.py
  agent/policy.py
  agent/self_play.py
  main.py
Config.py

Lines changed: 3 additions & 3 deletions
@@ -22,16 +22,16 @@ class _ConfigAgent:
     #     64,
     #     32,
     # )
-    # hidden_sizes = (256,)
-    hidden_sizes = tuple()
+    hidden_sizes = (256,)
+    # hidden_sizes = tuple()
     c = 0.2
     learning_rate = 1e-5
     debug = False
     pretrain = True


 class Config(_ConfigPaths, _ConfigAgent):
-    train = False
+    train = True
     max_results_held = 100
     minimal_relative_agent_improvement = 1.1
     min_games_to_replace_agents = 40
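
This hunk switches the agent from no hidden layers (hidden_sizes = tuple(), i.e. a purely linear model) to a single 256-unit hidden layer, and turns training on (train = True). The model-building code is not part of this commit; the following is only a hypothetical sketch of how a hidden_sizes tuple is commonly unrolled into an MLP (build_mlp and the PyTorch dependency are assumptions, not this repo's code):

import torch.nn as nn

def build_mlp(n_inputs, n_outputs, hidden_sizes=(256,)):
    # Hypothetical sketch: one Linear+ReLU pair per entry in hidden_sizes.
    # hidden_sizes = tuple() degenerates to a single Linear layer.
    layers, last = [], n_inputs
    for size in hidden_sizes:
        layers += [nn.Linear(last, size), nn.ReLU()]
        last = size
    layers.append(nn.Linear(last, n_outputs))
    return nn.Sequential(*layers)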

agent/policy.py

Lines changed: 1 addition & 1 deletion
@@ -21,5 +21,5 @@ def policy(
     all_moves = game.get_possible_actions()
     for _ in range(n_simulations):
         search(game.copy(), agent, c, N, visited, P, Q)
-    pi = [N[initial_state][a] for a in all_moves]
+    pi = np.array([N[initial_state][a] for a in all_moves])
     return pi, all_moves[np.argmax(pi)]
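
Wrapping the visit counts in np.array (rather than returning a plain list) lets callers do elementwise arithmetic on pi — the agent/self_play.py hunk below divides by pi.sum() to normalise it. A minimal illustration with made-up counts:

import numpy as np

visit_counts = [3, 7, 0, 10]  # N[initial_state][a] for each legal move
# visit_counts / sum(visit_counts)   # TypeError: a list has no elementwise division
pi = np.array(visit_counts)
print(pi / pi.sum())          # -> [0.15 0.35 0.   0.5 ]
print(int(np.argmax(pi)))     # index of the most-visited move -> 3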

agent/self_play.py

Lines changed: 2 additions & 5 deletions
@@ -35,18 +35,15 @@ def _perform_game(
     for turn in tqdm(count()):
         agent = id_to_agent[game.current_player.id]
         pi, action = policy(game, agent, Config.c, Config.n_simulations)
-        action_index = game.all_moves.index(action)
-        onehot_encoded_action = np.zeros(Config.n_actions)
-        onehot_encoded_action[action_index] = 1
-        states.append((game, action, 0))
+        states.append((game, pi / pi.sum(), 0))
         game = game.perform(action)
         if game.is_terminal():
             result = game.get_results()
             return (
                 list(
                     (
                         state[0].get_state(),
-                        np.eye(Config.n_actions)[game.all_moves.index(state[1])],
+                        state[1],
                         int(result[state[0].current_player.id] == 1),
                     )
                     for state in states
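
Self-play now stores the normalised visit distribution pi / pi.sum() as the policy target instead of one-hot encoding the single chosen action, the AlphaZero-style choice: the soft target keeps the relative strength of every legal move, so each position carries more training signal than a single hard label. Side by side, with made-up counts:

import numpy as np

n_actions = 4
pi = np.array([3, 7, 0, 10])  # visit counts returned by policy()

# Old target: one-hot vector for the action actually played (the argmax).
onehot = np.eye(n_actions)[int(np.argmax(pi))]   # -> [0. 0. 0. 1.]

# New target: the whole search distribution.
soft = pi / pi.sum()                             # -> [0.15 0.35 0.   0.5 ]

Note that pi is indexed over the position's legal moves rather than the full Config.n_actions action set, so after this change the length of each stored target follows len(all_moves).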

main.py

Lines changed: 21 additions & 2 deletions
@@ -92,8 +92,27 @@ def evaluation():
         bce, cce = eval_agent(agent, eval_set)
         if bce >= prev_bce and cce >= prev_bce:
             break
-        prev_bce = bce
-        prev_cce = cce
+        prev_bce = min(prev_bce, bce)
+        prev_cce = min(prev_cce, cce)
+
+
+# def evaluation():
+#     v_agent = LogisticRegression()
+#     p_agent = LogisticRegression()
+#     train_set = reduce(
+#         operator.add,
+#         (eval(path.read_text()) for path in Config.training_data_path.iterdir()),
+#     )
+#     eval_set = reduce(
+#         operator.add,
+#         (eval(path.read_text()) for path in Config.evaluation_data_path.iterdir()),
+#     )
+#     v_agent.fit(tuple(sample[0] for sample in train_set), tuple(sample[2] for sample in train_set))
+#     p_agent.fit(tuple(sample[0] for sample in train_set), np.argmax(np.array(tuple(sample[1] for sample in train_set)), axis=1))
+#     print(
+#         v_agent.score(tuple(sample[0] for sample in eval_set), tuple(sample[2] for sample in eval_set)),
+#         p_agent.score(tuple(sample[0] for sample in eval_set), np.argmax(np.array(tuple(sample[1] for sample in eval_set)), axis=1)),
+#     )


 def main():
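
The evaluation loop now tracks the best loss seen so far (min of the previous best and the current value) instead of the immediately preceding one, so a single noisy uptick no longer resets the baseline: the loop breaks only once neither loss beats its best value. The unchanged context line compares cce against prev_bce, which looks like a typo for prev_cce. The commented-out evaluation() variant fits sklearn LogisticRegression baselines for the value and policy heads, apparently kept as a sanity-check reference. A hedged sketch of the stopping pattern as presumably intended (train_step and evaluate are hypothetical placeholders, not functions from this repo):

import math

def fit_until_stalled(train_step, evaluate):
    # train_step/evaluate stand in for the repo's actual loop body and
    # its eval_agent(agent, eval_set) call.
    prev_bce = prev_cce = math.inf
    while True:
        train_step()
        bce, cce = evaluate()
        if bce >= prev_bce and cce >= prev_cce:
            break                          # neither loss improved on its best
        prev_bce = min(prev_bce, bce)      # best-so-far, not last-seen
        prev_cce = min(prev_cce, cce)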
