Skip to content

Commit 2ab0169

Browse files
refactor: separate staleness control from workflow execution (#444)
* isolate staleness control * . * pass test * polish docstring * polish docstring * rename to staleness manager * minor fix test_megatron_engine * fix naming * Update areal/core/staleness_manager.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 93fb172 commit 2ab0169

9 files changed

Lines changed: 879 additions & 70 deletions

areal/api/workflow_api.py

Lines changed: 70 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from areal.api.cli_args import InferenceEngineConfig
1919
from areal.api.engine_api import InferenceEngine
20-
from areal.api.io_struct import RolloutStat
20+
from areal.core.staleness_manager import StalenessManager
2121
from areal.experimental.openai.types import CompletionWithTokenLogpReward
2222
from areal.utils import logging
2323
from areal.utils.data import concat_padded_tensors, cycle_dataloader
@@ -256,24 +256,28 @@ def __init__(
256256
self,
257257
config: InferenceEngineConfig,
258258
inference_engine: "InferenceEngine",
259+
staleness_manager: StalenessManager | None = None,
259260
):
260261
self.max_concurrent_rollouts = (
261262
config.max_concurrent_rollouts or config.consumer_batch_size
262263
)
264+
self.consumer_batch_size = config.consumer_batch_size
265+
263266
self.config = config
264267
self.exiting = threading.Event()
265268
self.paused = threading.Event()
266-
self.lock = threading.Lock()
267269

268270
self.inference_engine = inference_engine
269271

272+
# Use provided staleness manager or create a default one
273+
# The manager will be properly initialized in initialize()
274+
self.staleness_manager = staleness_manager
275+
270276
qsize = config.queue_size or self.max_concurrent_rollouts * 16
271277
self.input_queue = queue.Queue(maxsize=qsize)
272278
self.output_queue = queue.Queue(maxsize=qsize)
273279
self.result_cache: List[_TimedResult] = []
274280

275-
self.rollout_stat = RolloutStat()
276-
277281
# For trajectory format checking
278282
self._expected_trajectory_keys: set | None = None
279283

@@ -282,18 +286,31 @@ def initialize(self, logger=None, train_data_parallel_size: int | None = None):
282286
logger = logging.getLogger("WorkflowExecutor")
283287
self.logger = logger
284288

285-
if train_data_parallel_size is not None:
286-
self.dp_world_size = train_data_parallel_size
287-
else:
288-
if dist.is_initialized():
289-
if not mpu.is_initialized():
290-
self.dp_world_size = dist.get_world_size()
291-
else:
292-
self.dp_world_size = mpu.get_data_parallel_world_size()
289+
# Initialize staleness manager if not provided
290+
if self.staleness_manager is None:
291+
if train_data_parallel_size is not None:
292+
dp_world_size = train_data_parallel_size
293293
else:
294-
self.dp_world_size = 1
294+
if dist.is_initialized():
295+
if not mpu.is_initialized():
296+
dp_world_size = dist.get_world_size()
297+
else:
298+
dp_world_size = mpu.get_data_parallel_world_size()
299+
else:
300+
dp_world_size = 1
301+
302+
# Apply data parallel scaling
303+
max_concurrent_rollouts = max(
304+
1, self.max_concurrent_rollouts // dp_world_size
305+
)
306+
consumer_batch_size = max(1, self.consumer_batch_size // dp_world_size)
307+
308+
self.staleness_manager = StalenessManager(
309+
max_concurrent_rollouts=max_concurrent_rollouts,
310+
consumer_batch_size=consumer_batch_size,
311+
max_staleness=self.config.max_head_offpolicyness,
312+
)
295313

296-
self.rollout_tasks: Dict[str, _RolloutTask] = {}
297314
self.rollout_thread = threading.Thread(
298315
target=self._rollout_thread, daemon=True
299316
) # set daemon=True to automatically exit when error occurs
@@ -304,17 +321,8 @@ def destroy(self):
304321
self.rollout_thread.join()
305322

306323
def get_capacity(self):
    """Return how many new rollouts may be launched right now.

    All staleness/concurrency arithmetic is delegated to the staleness
    manager, keyed on the inference engine's current weight version.
    The result can be zero or negative when we are at or over capacity.
    """
    current_version = self.inference_engine.get_version()
    return self.staleness_manager.get_capacity(current_version)
319327

320328
def _rollout_thread(self):
@@ -325,14 +333,13 @@ def _rollout_thread(self):
325333
traceback.print_exc()
326334

327335
async def _rollout_thread_async(self):
328-
rollout_tasks = self.rollout_tasks
336+
rollout_tasks: Dict[str, _RolloutTask] = {}
329337
rid = 0
330338
try:
331339
while not self.exiting.is_set():
332340
# Check capacity
333341
capacity = self.get_capacity()
334342
# Create new rollout task
335-
self.lock.acquire()
336343
while (
337344
capacity > 0
338345
and not self.paused.is_set()
@@ -348,19 +355,19 @@ async def _rollout_thread_async(self):
348355
rollout_tasks[str(rid)] = _RolloutTask(
349356
create_time=time.monotonic_ns(), task=task, task_input=x
350357
)
351-
self.rollout_stat.submitted += 1
352-
self.rollout_stat.running += 1
358+
# Notify staleness manager
359+
self.staleness_manager.on_rollout_submitted()
353360
if self.config.enable_rollout_tracing:
361+
stat = self.staleness_manager.get_stats()
354362
self.logger.info(
355363
f"Submit rollout rid {rid}. "
356-
f"Submit: {self.rollout_stat.submitted}, "
357-
f"running: {self.rollout_stat.running}, "
358-
f"accepted: {self.rollout_stat.accepted}."
364+
f"Submit: {stat.submitted}, "
365+
f"running: {stat.running}, "
366+
f"accepted: {stat.accepted}."
359367
)
360368
capacity -= 1
361369
rid += 1
362370
tasks = [x.task for x in rollout_tasks.values()]
363-
self.lock.release()
364371

365372
# Wait for rollout completion
366373
done = []
@@ -396,26 +403,25 @@ async def _rollout_thread_async(self):
396403
)
397404
assert traj is None or isinstance(traj, dict), traj
398405
task_rid = task.get_name()
399-
with self.lock:
400-
task_obj = rollout_tasks.pop(task_rid)
401-
self.rollout_stat.accepted += 1
402-
self.rollout_stat.running -= 1
403-
if self.config.enable_rollout_tracing:
404-
self.logger.info(
405-
f"Finish rollout {task_rid}. "
406-
f"Submit: {self.rollout_stat.submitted}, "
407-
f"running: {self.rollout_stat.running}, "
408-
f"accepted: {self.rollout_stat.accepted}."
409-
)
406+
task_obj = rollout_tasks.pop(task_rid)
410407

411408
task_input = task_obj.task_input
412-
if traj is not None and (
409+
# Check if trajectory should be accepted
410+
should_accept_traj = traj is not None and (
413411
task_input.should_accept is None
414412
or task_input.should_accept(traj)
415-
):
413+
)
414+
415+
if should_accept_traj:
416+
# Notify staleness manager of accepted rollout
417+
self.staleness_manager.on_rollout_accepted()
416418
if self.config.enable_rollout_tracing:
419+
stat = self.staleness_manager.get_stats()
417420
self.logger.info(
418-
f"Accept rollout result of task {task_rid}."
421+
f"Finish and accept rollout {task_rid}. "
422+
f"Submit: {stat.submitted}, "
423+
f"running: {stat.running}, "
424+
f"accepted: {stat.accepted}."
419425
)
420426
try:
421427
self.output_queue.put_nowait(
@@ -426,24 +432,30 @@ async def _rollout_thread_async(self):
426432
"Output queue full. Please increase queue_size."
427433
)
428434
else:
435+
# Rollout completed but was rejected
436+
# Only decrement running count since it was never accepted
437+
self.staleness_manager.on_rollout_rejected()
429438
if self.config.enable_rollout_tracing:
430-
self.logger.info(f"Rollout is rejected.")
431-
with self.lock:
432-
self.rollout_stat.accepted -= 1
439+
stat = self.staleness_manager.get_stats()
440+
self.logger.info(
441+
f"Finish but reject rollout {task_rid}. "
442+
f"Submit: {stat.submitted}, "
443+
f"running: {stat.running}, "
444+
f"accepted: {stat.accepted}."
445+
)
433446

434447
await asyncio.sleep(1)
435448
except Exception:
436449
traceback.print_exc()
437450
finally:
438451
# Cancel remaining tasks
439-
with self.lock:
440-
for task_obj in rollout_tasks.values():
441-
if not task_obj.task.done():
442-
task_obj.task.cancel()
443-
try:
444-
await task_obj.task
445-
except asyncio.CancelledError:
446-
pass
452+
for task_obj in rollout_tasks.values():
453+
if not task_obj.task.done():
454+
task_obj.task.cancel()
455+
try:
456+
await task_obj.task
457+
except asyncio.CancelledError:
458+
pass
447459

448460
def submit(
449461
self,

areal/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Core components for AREAL."""

areal/core/staleness_manager.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""Staleness-aware capacity manager for rollout generation.
2+
3+
This module provides the StalenessManager class which manages capacity
4+
and staleness constraints for asynchronous rollout generation in RL training.
5+
"""
6+
7+
from threading import Lock
8+
9+
from areal.api.io_struct import RolloutStat
10+
11+
12+
class StalenessManager:
    """Thread-safe bookkeeping of rollout capacity under staleness limits.

    Two constraints are enforced when handing out capacity:

    1. Concurrency: no more than ``max_concurrent_rollouts`` rollouts may
       be in flight simultaneously.
    2. Staleness: the total sample count produced so far (accepted plus
       running) may not outrun the trainer by more than ``max_staleness``
       versions' worth of consumer batches, so samples are never too
       off-policy by the time they are consumed.

    Parameters
    ----------
    max_concurrent_rollouts : int
        Upper bound on simultaneously running rollouts.
    consumer_batch_size : int
        Batch size the trainer consumes per training step.
    max_staleness : int
        Largest tolerated version gap (off-policyness) for a rollout.
    """

    def __init__(
        self,
        max_concurrent_rollouts: int,
        consumer_batch_size: int,
        max_staleness: int,
    ):
        """Store the capacity limits and set up thread-safe counters.

        Parameters
        ----------
        max_concurrent_rollouts : int
            Upper bound on simultaneously running rollouts.
        consumer_batch_size : int
            Batch size the trainer consumes per training step.
        max_staleness : int
            Largest tolerated version gap (off-policyness) for a rollout.
        """
        self.max_concurrent_rollouts = max_concurrent_rollouts
        self.consumer_batch_size = consumer_batch_size
        self.max_staleness = max_staleness

        # All reads/writes of rollout_stat go through this lock.
        self.lock = Lock()
        self.rollout_stat = RolloutStat()

    def get_capacity(self, current_version: int) -> int:
        """Compute how many new rollouts may be launched right now.

        The returned value is the tighter of two budgets:

        * concurrency budget: ``max_concurrent_rollouts - running``
        * staleness budget:
          ``(max_staleness + current_version + 1) * consumer_batch_size
          - (accepted + running)``

        Parameters
        ----------
        current_version : int
            Current version of the model weights.

        Returns
        -------
        int
            Available rollout slots; may be negative when over capacity.
        """
        with self.lock:
            running = self.rollout_stat.running
            produced = self.rollout_stat.accepted + running

            # Concurrency constraint (clamped so at least one slot exists).
            concurrency_room = max(1, self.max_concurrent_rollouts) - running

            # Staleness constraint: cap total samples relative to the
            # trainer's progress so nothing exceeds max_staleness by
            # consumption time.
            batch = max(1, self.consumer_batch_size)
            budget = (self.max_staleness + current_version + 1) * batch
            staleness_room = budget - produced

            return min(concurrency_room, staleness_room)

    def on_rollout_submitted(self) -> None:
        """Record that a rollout was submitted (thread-safe).

        Increments both the submitted and running counters.
        """
        with self.lock:
            self.rollout_stat.submitted += 1
            self.rollout_stat.running += 1

    def on_rollout_accepted(self) -> None:
        """Record that a rollout finished and was accepted (thread-safe).

        Moves one rollout from running to accepted.
        """
        with self.lock:
            self.rollout_stat.accepted += 1
            self.rollout_stat.running -= 1

    def on_rollout_rejected(self) -> None:
        """Record that a rollout finished but was rejected (thread-safe).

        Used when the workflow returned None or ``should_accept`` filtered
        the trajectory out. The rollout was never counted as accepted, so
        only the running counter is decremented.
        """
        with self.lock:
            self.rollout_stat.running -= 1

    def get_stats(self) -> RolloutStat:
        """Return a consistent snapshot of the rollout statistics.

        Returns
        -------
        RolloutStat
            Copy of the current counters (submitted, accepted, running).
        """
        with self.lock:
            snapshot = RolloutStat(
                submitted=self.rollout_stat.submitted,
                accepted=self.rollout_stat.accepted,
                running=self.rollout_stat.running,
            )
        return snapshot

0 commit comments

Comments (0)