11from itertools import pairwise , starmap
22
3+ import numpy as np
34from torch import nn , Tensor
45
6+ from Config import Config
7+
58
69class Agent (nn .Module ):
710 _input_size_dictionary = {
@@ -11,27 +14,29 @@ class Agent(nn.Module):
def __init__(
    self,
    n_players: int,
    hidden_sizes: tuple = Config.hidden_sizes,
    n_moves: int = 46,
):
    """Build the shared trunk and the policy / value heads.

    Args:
        n_players: Number of players; used to look up the input width
            through ``self._get_size``.
        hidden_sizes: Output width of each successive ``nn.Linear`` layer
            in the trunk. May be empty, in which case both heads read the
            input features directly.
        n_moves: Width of the policy head output.
    """
    super().__init__()
    self.tanh = nn.Tanh()
    self.softmax = nn.Softmax(dim=1)

    first_size = self._get_size(n_players)
    sizes = first_size, *hidden_sizes

    # ModuleList registers every layer as a sub-module, so their parameters
    # are visible to optimizers, ``.to(device)`` and ``state_dict()``.
    self.layers = nn.ModuleList(starmap(nn.Linear, pairwise(sizes)))

    # Plain Python flag (neither a parameter nor a buffer), so it is not
    # part of ``state_dict()``; ``forward`` reads and sets it.
    self.trained = False

    # ``sizes[-1]`` instead of ``hidden_sizes[-1]`` keeps the heads valid
    # even when ``hidden_sizes`` is empty.
    self.fc_v = nn.Linear(sizes[-1], 1)
    self.fc_p = nn.Linear(sizes[-1], n_moves)
    self._n_moves = n_moves
2930
3031 def _get_size (self , n_players : int ) -> int :
3132 return self ._input_size_dictionary [n_players ]
3233
def forward(self, state: Tensor):
    """Return ``(policy, value)`` for ``state``.

    While the module is in evaluation mode and has never been run in
    training mode, the network is bypassed and a random policy / value pair
    is returned instead. Otherwise ``state`` goes through every layer in
    ``self.layers`` and then through the two heads.

    Args:
        state: Input features. The regular path applies
            ``Softmax(dim=1)``, so it needs at least two dimensions.

    Returns:
        A tuple ``(policy, value)``: ``policy`` is softmax-normalised along
        dim 1 with ``self._n_moves`` entries in its last dimension, and
        ``value`` has a last dimension of 1 with entries in ``[-1, 1]``.
    """
    if not self.training and not self.trained:
        # Mirror the leading (batch) dimensions of the input so the
        # fallback has the same shape as the real path would produce.
        # Non-tensor or 1-D inputs keep the historical batch of one.
        is_tensor = isinstance(state, Tensor)
        if is_tensor and state.dim() >= 2:
            lead = tuple(state.shape[:-1])
        else:
            lead = (1,)
        # Policy is drawn before value, as before, so seeded NumPy runs
        # with a batch of one reproduce the same numbers.
        policy = self.softmax(Tensor(np.random.random((*lead, self._n_moves))))
        value = Tensor(np.random.uniform(-1, 1, (*lead, 1)))
        if is_tensor:
            # Keep the outputs on the caller's device, like the real path.
            policy = policy.to(state.device)
            value = value.to(state.device)
        return policy, value

    # NOTE(review): this flag flips on the first call made in training mode,
    # whether or not any optimizer step has happened yet - confirm intended.
    self.trained = True

    # NOTE(review): no activation is applied between the linear layers, so
    # the trunk composes into a single affine map - confirm intended.
    for layer in self.layers:
        state = layer(state)
    return self.softmax(self.fc_p(state)), self.tanh(self.fc_v(state))
0 commit comments