@@ -1,7 +1,7 @@
 import re
 from collections import deque
 from copy import deepcopy
-from itertools import count
+from itertools import count, islice
 from pathlib import Path
 
 import torch
@@ -16,19 +16,28 @@ def main():
     training_buffer = deque(maxlen=Config.training_buffer_len)
     agents = deque((Agent(Config.n_players) for _ in range(Config.n_players)), maxlen=Config.n_players)
     if Config.pretrain:
+        for agent, checkpoint_index in zip(islice(reversed(agents), 1, None), sorted((int(path.name.split('.')[0]) for path in Config.model_path.iterdir()), reverse=True)):
+            agent.load_state_dict(torch.load(Config.model_path.joinpath(f'{checkpoint_index}.pth')))
+        agents[-1].load_state_dict(torch.load(Config.model_path.joinpath(f"{max(int(path.name.split('.')[0]) for path in Config.model_path.iterdir())}.pth")))
         training_buffer += list(map(eval, map(Path.read_text, sorted(Config.data_path.iterdir(), key=lambda path: int(path.name), reverse=True)[:Config.training_buffer_len])))
         train_agent(agents[-1], training_buffer)
     scores = deque(maxlen=Config.max_results_held)
     for _ in (count() if Config.n_games is None else range(Config.n_games)):
         buffer, winner = self_play(agents)
-        Config.data_path.joinpath(str(max((*tuple(int(path.name) for path in Config.data_path.iterdir()), -1)) + 1)).write_text(str((list(buffer[0][0]), list(buffer[0][1]), buffer[0][2])))
+        start_index = max((*tuple(int(path.name) for path in Config.data_path.iterdir()), -1)) + 1  # first unused numeric filename
+        for sample_index, sample in enumerate(buffer, start_index):
+            Config.data_path.joinpath(str(sample_index)).write_text(str((list(sample[0]), list(sample[1]), sample[2])))
         scores.append(agents[-1] is winner)
-        if len(scores) >= Config.min_games_to_replace_agents and sum(scores) > Config.minimal_relative_agent_improvement * len(scores) / len(agents):
+        if (len(scores) < Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * Config.min_games_to_replace_agents / len(agents)) or (len(scores) >= Config.min_games_to_replace_agents and sum(scores) >= Config.minimal_relative_agent_improvement * len(scores) / len(agents)):
             torch.save(agents[-1].state_dict(), Config.model_path.joinpath(str(max(map(int, (*re.findall(r'\d+', ''.join(map(str, Config.model_path.iterdir()))), -1))) + 1) + ".pth"))
-            agents.append(Agent(Config.n_players).load_state_dict(deepcopy(agents[-1].state_dict())))
+            agents.append(Agent(Config.n_players))
+            agents[-1].load_state_dict(deepcopy(agents[-2].state_dict()))  # after the append, the trained agent sits at -2
             agents[-1].training = True
             scores = deque(maxlen=Config.max_results_held)
-        print(sum(scores) / len(scores), len(scores))
+        elif len(scores) >= Config.min_games_to_replace_agents:
+            print(f'{len(scores)} {sum(scores) / len(scores):.2f}')
+        else:
+            print(f'{len(scores)} {sum(scores)}/{len(scores)}')
         training_buffer += buffer
         train_agent(agents[-1], training_buffer)
 
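Note on the replacement rule: in self-play among `len(agents)` equally strong players, each one's baseline win rate is `1 / len(agents)`, so the second clause asks the training agent for a win rate of at least `minimal_relative_agent_improvement / len(agents)` over the current window. With four agents and an improvement factor of 1.2 (hypothetical values), that is a 30% win rate, i.e. at least 30 wins per 100 games. The new first clause applies the same bar scaled to the full `min_games_to_replace_agents` window, so an agent that has already banked that many wins is promoted before the window fills.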
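The weight-copy fix above leans on how a bounded `deque` shifts on `append`. A minimal sketch of that behaviour, assuming `Agent` is an ordinary `nn.Module` (the `TinyAgent` stand-in and its sizes below are hypothetical, not part of the commit):

```python
from collections import deque
from copy import deepcopy

import torch
from torch import nn


class TinyAgent(nn.Module):  # hypothetical stand-in for the repo's Agent
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)


agents = deque((TinyAgent() for _ in range(3)), maxlen=3)
trained = agents[-1]            # the agent that was just trained and saved
agents.append(TinyAgent())      # evicts the oldest agent from the left end
assert agents[-2] is trained    # the trained agent now sits at index -2

# Copy the trained weights into the fresh agent, as the loop does above.
agents[-1].load_state_dict(deepcopy(agents[-2].state_dict()))
assert torch.equal(agents[-1].layer.weight, trained.layer.weight)
```

Without the index change, the fresh agent would reload its own random initialisation, making the copy a no-op.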