|
1 | | -from collections import deque |
| 1 | +from typing import Sequence |
2 | 2 |
|
3 | 3 | import numpy as np |
| 4 | +from sklearn.metrics import accuracy_score |
4 | 5 | from torch import nn, optim |
5 | 6 | from torch.utils.data import DataLoader |
6 | 7 |
|
|
9 | 10 | from .RLDataset import RLDataset |
10 | 11 |
|
11 | 12 |
|
12 | | -def train_agent(agent: Agent, train_data: deque[tuple[tuple, np.array, int]]): |
| 13 | +def train_agent(agent: Agent, train_data: Sequence[tuple[tuple, np.array, int]]): |
13 | 14 | agent.train() |
| 15 | + optimizer = optim.Adam(agent.parameters(), lr=Config.learning_rate) |
| 16 | + _loop(agent, train_data, optimizer) |
| 17 | + |
| 18 | + |
| 19 | +def eval_agent(agent: Agent, eval_set: Sequence[tuple[tuple, np.array, int]]): |
| 20 | + agent.eval() |
| 21 | + return _loop(agent, eval_set, batch_size=len(eval_set)) |
| 22 | + |
| 23 | + |
| 24 | +def _loop( |
| 25 | + agent: Agent, |
| 26 | + dataset: Sequence[tuple[tuple, np.array, int]], |
| 27 | + optimizer: optim.Optimizer = None, |
| 28 | + batch_size=Config.train_batch_size, |
| 29 | +): |
| 30 | + is_optimizer = optimizer is not None |
14 | 31 | categorical_cross_entropy = nn.CrossEntropyLoss() |
15 | 32 | mse = nn.MSELoss() |
16 | | - optimizer = optim.Adam(agent.parameters(), lr=Config.learning_rate) |
17 | | - dataset = RLDataset(train_data) |
18 | | - loader = DataLoader(dataset, batch_size=Config.train_batch_size) |
19 | | - for batch in loader: |
20 | | - state, policy, win_probability = batch |
| 33 | + dataset = RLDataset(dataset) |
| 34 | + loader = DataLoader(dataset, batch_size=batch_size) |
| 35 | + for index, (state, policy, win_probability) in enumerate(loader): |
21 | 36 | state, policy, win_probability = ( |
22 | 37 | state.float(), |
23 | 38 | policy.float(), |
24 | 39 | win_probability.float(), |
25 | 40 | ) |
26 | | - optimizer.zero_grad() |
| 41 | + if is_optimizer: |
| 42 | + optimizer.zero_grad() |
27 | 43 | output_policy, output_v = agent(state) |
28 | 44 | bce = mse(output_v, win_probability) |
29 | 45 | cce = categorical_cross_entropy(output_policy, policy) |
30 | | - bce.backward(retain_graph=True) |
31 | | - cce.backward() |
32 | | - optimizer.step() |
| 46 | + if is_optimizer: |
| 47 | + bce.backward(retain_graph=True) |
| 48 | + cce.backward() |
| 49 | + optimizer.step() |
| 50 | + else: |
| 51 | + print( |
| 52 | + accuracy_score(win_probability, np.sign(output_v.detach().numpy())), |
| 53 | + accuracy_score( |
| 54 | + np.argmax(policy.detach().numpy(), axis=1), |
| 55 | + np.argmax(output_policy.detach().numpy(), axis=1), |
| 56 | + ), |
| 57 | + ) |
| 58 | + return bce.item(), cce.item() |
0 commit comments