1+ import random
12from collections import deque
2- from itertools import count , cycle , islice
3- from more_itertools import windowed
3+ from itertools import count
44
55import numpy as np
66from tqdm import tqdm
77
88from Config import Config
9+ from src .Game import Game
910from .Agent import Agent
1011from .policy import policy
11- from src .Game import Game
1212
1313
def self_play(
    agents: deque[Agent],
) -> tuple[list[tuple[np.ndarray, np.ndarray, int]], list[Agent]]:
    """Play one self-play game among a random selection of agents.

    ``Config.n_players`` agents are sampled without replacement from the
    pool and assigned to the game's player seats in random order.

    Args:
        agents: Pool of candidate agents; must contain at least
            ``Config.n_players`` entries for ``random.sample`` to succeed.

    Returns:
        A tuple ``(states, winners)`` where ``states`` is the list of
        ``(board_state, one_hot_action, won_flag)`` training examples
        collected by ``_perform_game`` and ``winners`` holds the winning
        agent of each game played (a single game here).
    """
    states: list[tuple[np.ndarray, np.ndarray, int]] = []
    winners: list[Agent] = []
    game = Game(n_players=Config.n_players)
    # Inference only: switch off training-specific behavior (e.g. dropout).
    for agent in agents:
        agent.eval()
    # Seat a random subset of the agent pool: map each game player id to
    # the agent controlling that seat.
    id_to_agent = {
        player.id: agent
        for agent, player in zip(
            random.sample(agents, Config.n_players), game.players
        )
    }
    results, winner = _perform_game(game, [], id_to_agent)
    states += results
    winners.append(winner)
    return states, winners
3329
3430
@@ -49,10 +45,10 @@ def _perform_game(
4945 list (
5046 (
5147 state [0 ].get_state (),
52- ( onehot_encoded_action := np .zeros (Config .n_actions ), onehot_encoded_action . __setitem__ ( game .all_moves .index (state [1 ]), 1 ))[ 0 ],
48+ np .eye (Config .n_actions )[ game .all_moves .index (state [1 ])],
5349 int (result [state [0 ].current_player .id ] == 1 ),
5450 )
55- for state in states
51+ for state in states if state [ 1 ] != game . null_move
5652 ),
5753 id_to_agent [
5854 next (player .id for player in game .players if result [player .id ])
0 commit comments