1717
1818from areal .api .cli_args import InferenceEngineConfig
1919from areal .api .engine_api import InferenceEngine
20- from areal .api . io_struct import RolloutStat
20+ from areal .core . staleness_manager import StalenessManager
2121from areal .experimental .openai .types import CompletionWithTokenLogpReward
2222from areal .utils import logging
2323from areal .utils .data import concat_padded_tensors , cycle_dataloader
@@ -256,24 +256,28 @@ def __init__(
256256 self ,
257257 config : InferenceEngineConfig ,
258258 inference_engine : "InferenceEngine" ,
259+ staleness_manager : StalenessManager | None = None ,
259260 ):
260261 self .max_concurrent_rollouts = (
261262 config .max_concurrent_rollouts or config .consumer_batch_size
262263 )
264+ self .consumer_batch_size = config .consumer_batch_size
265+
263266 self .config = config
264267 self .exiting = threading .Event ()
265268 self .paused = threading .Event ()
266- self .lock = threading .Lock ()
267269
268270 self .inference_engine = inference_engine
269271
272+ # Use provided staleness manager or create a default one
273+ # The manager will be properly initialized in initialize()
274+ self .staleness_manager = staleness_manager
275+
270276 qsize = config .queue_size or self .max_concurrent_rollouts * 16
271277 self .input_queue = queue .Queue (maxsize = qsize )
272278 self .output_queue = queue .Queue (maxsize = qsize )
273279 self .result_cache : List [_TimedResult ] = []
274280
275- self .rollout_stat = RolloutStat ()
276-
277281 # For trajectory format checking
278282 self ._expected_trajectory_keys : set | None = None
279283
@@ -282,18 +286,31 @@ def initialize(self, logger=None, train_data_parallel_size: int | None = None):
282286 logger = logging .getLogger ("WorkflowExecutor" )
283287 self .logger = logger
284288
285- if train_data_parallel_size is not None :
286- self .dp_world_size = train_data_parallel_size
287- else :
288- if dist .is_initialized ():
289- if not mpu .is_initialized ():
290- self .dp_world_size = dist .get_world_size ()
291- else :
292- self .dp_world_size = mpu .get_data_parallel_world_size ()
289+ # Initialize staleness manager if not provided
290+ if self .staleness_manager is None :
291+ if train_data_parallel_size is not None :
292+ dp_world_size = train_data_parallel_size
293293 else :
294- self .dp_world_size = 1
294+ if dist .is_initialized ():
295+ if not mpu .is_initialized ():
296+ dp_world_size = dist .get_world_size ()
297+ else :
298+ dp_world_size = mpu .get_data_parallel_world_size ()
299+ else :
300+ dp_world_size = 1
301+
302+ # Apply data parallel scaling
303+ max_concurrent_rollouts = max (
304+ 1 , self .max_concurrent_rollouts // dp_world_size
305+ )
306+ consumer_batch_size = max (1 , self .consumer_batch_size // dp_world_size )
307+
308+ self .staleness_manager = StalenessManager (
309+ max_concurrent_rollouts = max_concurrent_rollouts ,
310+ consumer_batch_size = consumer_batch_size ,
311+ max_staleness = self .config .max_head_offpolicyness ,
312+ )
295313
296- self .rollout_tasks : Dict [str , _RolloutTask ] = {}
297314 self .rollout_thread = threading .Thread (
298315 target = self ._rollout_thread , daemon = True
299316 ) # set daemon=True to automatically exit when error occurs
@@ -304,17 +321,8 @@ def destroy(self):
304321 self .rollout_thread .join ()
305322
def get_capacity(self):
    """Return how many new rollout tasks may be submitted at this moment.

    All concurrency and staleness bookkeeping is delegated to the
    staleness manager, which is queried with the inference engine's
    current weight version.

    NOTE(review): reconstructed from a mangled diff view of this file;
    confirm the exact statement order against the actual source.
    """
    current_version = self.inference_engine.get_version()
    return self.staleness_manager.get_capacity(current_version)
319327
320328 def _rollout_thread (self ):
@@ -325,14 +333,13 @@ def _rollout_thread(self):
325333 traceback .print_exc ()
326334
327335 async def _rollout_thread_async (self ):
328- rollout_tasks = self . rollout_tasks
336+ rollout_tasks : Dict [ str , _RolloutTask ] = {}
329337 rid = 0
330338 try :
331339 while not self .exiting .is_set ():
332340 # Check capacity
333341 capacity = self .get_capacity ()
334342 # Create new rollout task
335- self .lock .acquire ()
336343 while (
337344 capacity > 0
338345 and not self .paused .is_set ()
@@ -348,19 +355,19 @@ async def _rollout_thread_async(self):
348355 rollout_tasks [str (rid )] = _RolloutTask (
349356 create_time = time .monotonic_ns (), task = task , task_input = x
350357 )
351- self . rollout_stat . submitted += 1
352- self .rollout_stat . running += 1
358+ # Notify staleness manager
359+ self .staleness_manager . on_rollout_submitted ()
353360 if self .config .enable_rollout_tracing :
361+ stat = self .staleness_manager .get_stats ()
354362 self .logger .info (
355363 f"Submit rollout rid { rid } . "
356- f"Submit: { self . rollout_stat .submitted } , "
357- f"running: { self . rollout_stat .running } , "
358- f"accepted: { self . rollout_stat .accepted } ."
364+ f"Submit: { stat .submitted } , "
365+ f"running: { stat .running } , "
366+ f"accepted: { stat .accepted } ."
359367 )
360368 capacity -= 1
361369 rid += 1
362370 tasks = [x .task for x in rollout_tasks .values ()]
363- self .lock .release ()
364371
365372 # Wait for rollout completion
366373 done = []
@@ -396,26 +403,25 @@ async def _rollout_thread_async(self):
396403 )
397404 assert traj is None or isinstance (traj , dict ), traj
398405 task_rid = task .get_name ()
399- with self .lock :
400- task_obj = rollout_tasks .pop (task_rid )
401- self .rollout_stat .accepted += 1
402- self .rollout_stat .running -= 1
403- if self .config .enable_rollout_tracing :
404- self .logger .info (
405- f"Finish rollout { task_rid } . "
406- f"Submit: { self .rollout_stat .submitted } , "
407- f"running: { self .rollout_stat .running } , "
408- f"accepted: { self .rollout_stat .accepted } ."
409- )
406+ task_obj = rollout_tasks .pop (task_rid )
410407
411408 task_input = task_obj .task_input
412- if traj is not None and (
409+ # Check if trajectory should be accepted
410+ should_accept_traj = traj is not None and (
413411 task_input .should_accept is None
414412 or task_input .should_accept (traj )
415- ):
413+ )
414+
415+ if should_accept_traj :
416+ # Notify staleness manager of accepted rollout
417+ self .staleness_manager .on_rollout_accepted ()
416418 if self .config .enable_rollout_tracing :
419+ stat = self .staleness_manager .get_stats ()
417420 self .logger .info (
418- f"Accept rollout result of task { task_rid } ."
421+ f"Finish and accept rollout { task_rid } . "
422+ f"Submit: { stat .submitted } , "
423+ f"running: { stat .running } , "
424+ f"accepted: { stat .accepted } ."
419425 )
420426 try :
421427 self .output_queue .put_nowait (
@@ -426,24 +432,30 @@ async def _rollout_thread_async(self):
426432 "Output queue full. Please increase queue_size."
427433 )
428434 else :
435+ # Rollout completed but was rejected
436+ # Only decrement running count since it was never accepted
437+ self .staleness_manager .on_rollout_rejected ()
429438 if self .config .enable_rollout_tracing :
430- self .logger .info (f"Rollout is rejected." )
431- with self .lock :
432- self .rollout_stat .accepted -= 1
439+ stat = self .staleness_manager .get_stats ()
440+ self .logger .info (
441+ f"Finish but reject rollout { task_rid } . "
442+ f"Submit: { stat .submitted } , "
443+ f"running: { stat .running } , "
444+ f"accepted: { stat .accepted } ."
445+ )
433446
434447 await asyncio .sleep (1 )
435448 except Exception :
436449 traceback .print_exc ()
437450 finally :
438451 # Cancel remaining tasks
439- with self .lock :
440- for task_obj in rollout_tasks .values ():
441- if not task_obj .task .done ():
442- task_obj .task .cancel ()
443- try :
444- await task_obj .task
445- except asyncio .CancelledError :
446- pass
452+ for task_obj in rollout_tasks .values ():
453+ if not task_obj .task .done ():
454+ task_obj .task .cancel ()
455+ try :
456+ await task_obj .task
457+ except asyncio .CancelledError :
458+ pass
447459
448460 def submit (
449461 self ,
0 commit comments