1 | | -from collections import defaultdict |
2 | | -from dataclasses import astuple |
3 | | -from math import sqrt |
| 1 | +from collections import deque |
4 | 2 |
5 | 3 | import numpy as np |
6 | | -from torch import nn, Tensor |
7 | | -from tqdm import tqdm |
| 4 | +from torch import nn, optim |
| 5 | +from torch.utils.data import DataLoader |
8 | 6 |
9 | 7 | from Config import Config |
10 | | -from src.Game import Game |
11 | 8 | from .Agent import Agent |
12 | | - |
13 | | - |
14 | | -def train_agent(): |
15 | | - agent = Agent(Config.n_players) |
16 | | - agent.eval() |
17 | | - examples = [] |
18 | | - examples_per_game = [] |
19 | | - for i in range(Config.n_games): |
20 | | - game = Game(n_players=Config.n_players) |
21 | | - while True: |
22 | | - pi, action = policy(game, agent, 1, Config.n_simulations) |
23 | | - examples_per_game.append((game, pi, 0)) |
24 | | - game = game.perform(action) |
25 | | - print(len(game.players[1].cards), game.players[1].points) |
26 | | - if game.is_terminal(): |
27 | | - for example in examples_per_game: |
28 | | - example[2] = game.get_state() |
29 | | - break |
30 | | - examples += examples_per_game |
31 | | - break |
32 | | - return examples |
33 | | - |
34 | | - |
35 | | -def search( |
36 | | - game: Game, |
37 | | - agent: nn.Module, |
38 | | - c: float, |
39 | | - N: defaultdict, |
40 | | - visited: set, |
41 | | - P: defaultdict, |
42 | | - Q: defaultdict, |
43 | | -): |
44 | | - state = game.get_state() |
45 | | - if game.is_terminal(): |
46 | | - return game.get_results()[game.current_player] |
47 | | - if state not in visited: |
48 | | - visited.add(state) |
49 | | - move_scores, v = agent(Tensor([state])) |
50 | | - tuple( |
51 | | - P[state].__setitem__(move, move_scores[0, index]) |
52 | | - for index, move in enumerate(game.all_moves) |
53 | | - ) |
54 | | - return -v |
55 | | - |
56 | | - action = max( |
57 | | - game.get_possible_actions(), |
58 | | - key=lambda action: Q[state].get(action, 1) |
59 | | - + c * P[state][action] * sqrt(sum(N[state].values())) / (1 + N[state][action]), |
60 | | - ) |
61 | | - |
62 | | - next_game_state = game.perform(action) |
63 | | - v = search(next_game_state, agent, c, N, visited, P, Q) |
64 | | - |
65 | | - Q[state][action] = (N[state][action] * Q[state].get(action, 1) + v) / ( |
66 | | - N[state][action] + 1 |
67 | | - ) |
68 | | - N[state][action] += 1 |
69 | | - return -v |
70 | | - |
71 | | - |
72 | | -def policy( |
73 | | - game: Game, |
74 | | - agent: nn.Module, |
75 | | - c: float, |
76 | | - n_simulations: int, |
77 | | -): |
78 | | - N = defaultdict(lambda: defaultdict(int)) |
79 | | - visited = set() |
80 | | - P = defaultdict(dict) |
81 | | - Q = defaultdict(dict) |
82 | | - initial_state = game.get_state() |
83 | | - all_moves = game.get_possible_actions() |
84 | | - for _ in tqdm(range(n_simulations)): |
85 | | - search(game, agent, c, N, visited, P, Q) |
86 | | - pi = [N[initial_state][a] for a in all_moves] |
87 | | - return pi, all_moves[np.argmax(pi)] |
| 9 | +from .RLDataset import RLDataset |
| 10 | + |
| 11 | + |
| 12 | +def train_agent(agent: Agent, train_data: deque[tuple[tuple, np.ndarray, int]]): |
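| 13 | +    # Each example is a (state, search policy target, game outcome) tuple |
| 14 | +    # from self-play; the deque serves as a replay window of recent games. |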
| 15 | +    agent.train() |
| 16 | +    categorical_cross_entropy = nn.CrossEntropyLoss() |
| 17 | +    mse = nn.MSELoss() |
| 18 | +    optimizer = optim.Adam(agent.parameters(), lr=Config.learning_rate) |
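| 19 | +    # RLDataset is assumed to wrap the example deque in the torch Dataset interface. |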
| 20 | +    dataset = RLDataset(train_data) |
| 21 | +    loader = DataLoader(dataset, batch_size=Config.train_batch_size) |
| 22 | +    for batch in loader: |
| 23 | +        state, policy, win_probability = batch |
| 24 | +        state, policy, win_probability = state.float(), policy.float(), win_probability.float() |
| 25 | +        optimizer.zero_grad() |
| 26 | +        output_policy, output_v = agent(state) |
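| 27 | +        # AlphaZero-style objective: MSE between the value head and the game |
| 28 | +        # outcome plus cross-entropy between the policy head and the policy target. |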
| 29 | +        value_loss = mse(output_v, win_probability) |
| 30 | +        policy_loss = categorical_cross_entropy(output_policy, policy) |
| 31 | +        # A single backward pass on the summed loss accumulates gradients |
| 32 | +        # from both heads, so retain_graph=True is not needed. |
| 33 | +        (value_loss + policy_loss).backward() |
| 34 | +        optimizer.step() |