Commit 8d41ab3
refactor: use callbacks to implement xccl weight transfer and avoid busy waiting during rollout (#769)
This PR implements a RolloutCallback that redirects the method calls of the RemoteInferenceEngine inside TrainEngine to the RolloutController. It has two major benefits:

1. When updating weights under the controller, method calls are passed through callbacks to the engine, so the weight-update logic is no longer duplicated between the engine and the controller.
2. Callbacks let the engine notify the controller when a rollout is finished, avoiding the original busy-waiting logic.

Validated with the local and Slurm schedulers using FSDP.

Future work:
- Test with RayScheduler and remove the hard-coded branching logic (e.g., `if 'rollout' in role: ...`).
- Test with FSDP + TP/CP and vLLM.
- Test with the Megatron backend until the previous end-to-end experiment (e.g., tau2-bench) runs fine.

Commit history:
* implement distributed xccl weight update and rollout waiting as callbacks
* revert unused functions
* add rollout callback
* resolve pr comments
* migrate to single-controller integration test
* fix tests
* fix tests
* fix
* fix gemini comment
* fix
* fix pytest skipif
* use properly cleaned up process pool executor for callbacks
* fix vllm
* minor fix
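To illustrate the second benefit, here is a minimal, self-contained sketch of the pattern; the `_Controller` class and its method names are hypothetical illustrations, not this repository's API:

```python
from concurrent.futures import Future


class _Controller:
    """Hypothetical stand-in for the rollout controller (illustration only)."""

    def __init__(self) -> None:
        self._rollout_done: Future[dict] = Future()

    # Old style: the train side would poll a flag like this in a sleep loop.
    def rollout_finished(self) -> bool:
        return self._rollout_done.done()

    # New style: the engine's callback resolves a Future when the rollout ends...
    def on_rollout_finished(self, batch: dict) -> None:
        self._rollout_done.set_result(batch)

    # ...and the waiting side blocks exactly once instead of busy waiting.
    def wait_rollout(self) -> dict:
        return self._rollout_done.result()


ctl = _Controller()
ctl.on_rollout_finished({"reward": 1.0})  # engine-side completion callback
print(ctl.wait_rollout())                 # returns immediately; no polling loop
```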
Parent: a450f23

33 files changed: 779 additions & 555 deletions

areal/api/cli_args.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1096,7 +1096,7 @@ class InferenceEngineConfig:
         metadata={"help": "Request scheduling policy", "choices": ["round_robin"]},
     )
     setup_timeout: float = field(
-        default=120.0,
+        default=300.0,
         metadata={
             "help": "Timeout in seconds of connecting to remote servers or launching local servers."
         },
@@ -1382,7 +1382,7 @@ class ClusterSpecConfig:
 class SchedulerConfig:
     """Configuration for worker scheduling. Used in the single-controller mode. Experimental."""

-    type: str = field(default="local")
+    type: str | None = field(default=None)
     endpoint: str = field(default="http://localhost:8081")
     deploy_mode: str = field(default="separation")
     functioncall_service_domain: str = field(default="http://localhost:8080")
```
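Both changed fields are ordinary dataclass defaults, so callers that relied on the old values can restore them explicitly. A minimal sketch, assuming the two dataclasses can be instantiated with keyword overrides alone:

```python
from areal.api.cli_args import InferenceEngineConfig, SchedulerConfig

# setup_timeout now defaults to 300s; raise it further for slow-starting servers.
infer_cfg = InferenceEngineConfig(setup_timeout=600.0)

# SchedulerConfig.type now defaults to None rather than "local", so the
# single-controller scheduler type must be opted into explicitly.
sched_cfg = SchedulerConfig(type="local")
```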

areal/api/io_struct.py

Lines changed: 5 additions & 14 deletions
```diff
@@ -14,7 +14,6 @@
 from areal.api.cli_args import GenerationHyperparameters
 from areal.platforms import current_platform
 from areal.utils import logging
-from areal.utils.network import find_free_ports, gethostip

 if TYPE_CHECKING:
     from transformers import AutoProcessor
@@ -126,9 +125,9 @@ class WeightUpdateMeta:
     path: str | None = None
     alloc_mode: AllocationMode | None = None

-    nccl_master_address: str = "127.0.0.1"
-    nccl_master_port: int = 29500
-    nccl_group_name: str = "update_weight_group"
+    nccl_master_address: str | None = None
+    nccl_master_port: int | None = None
+    nccl_group_name: str | None = None
     weight_chunked_mem_mb: int = 1024

     use_lora: bool = False
@@ -172,35 +171,27 @@ def from_disk(
     def from_megatron_xccl(
         cls,
         allocation_mode: AllocationMode,
-        nccl_group_name: str = "update_weight_group",
         weight_chunked_mem_mb: int = 1024,
     ):
         return cls(
-            type=current_platform.communication_backend,
+            type="xccl",
             alloc_mode=allocation_mode,
-            nccl_master_address=gethostip(),
-            nccl_master_port=find_free_ports(1)[0],
-            nccl_group_name=nccl_group_name,
             weight_chunked_mem_mb=weight_chunked_mem_mb,
         )

     @classmethod
     def from_fsdp_xccl(
         cls,
         allocation_mode: AllocationMode,
-        nccl_group_name: str = "update_weight_group",
         weight_chunked_mem_mb: int = 1024,
         use_lora: bool = False,
         lora_name: str = "",
         lora_int_id: int = 0,
         base_model_name: str = "",
     ):
         return cls(
-            type=current_platform.communication_backend,
+            type="xccl",
             alloc_mode=allocation_mode,
-            nccl_master_address=gethostip(),
-            nccl_master_port=find_free_ports(1)[0],
-            nccl_group_name=nccl_group_name,
             weight_chunked_mem_mb=weight_chunked_mem_mb,
             use_lora=use_lora,
             lora_name=lora_name,
```
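After this change, the factory methods no longer pick an NCCL rendezvous eagerly via gethostip()/find_free_ports(); the address, port, and group name start as None and are filled in later. A sketch of the resulting meta, where `alloc` is a placeholder AllocationMode whose construction is outside this diff:

```python
from areal.api.io_struct import WeightUpdateMeta

# `alloc` is a placeholder AllocationMode instance (not built in this diff).
meta = WeightUpdateMeta.from_fsdp_xccl(allocation_mode=alloc)

assert meta.type == "xccl"
# Rendezvous details are now deferred rather than chosen at construction time:
assert meta.nccl_master_address is None
assert meta.nccl_master_port is None
assert meta.nccl_group_name is None
```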
Lines changed: 211 additions & 0 deletions (new file)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
import atexit
2+
import threading
3+
from concurrent.futures import Future, ThreadPoolExecutor
4+
from dataclasses import dataclass
5+
from typing import Any
6+
7+
import requests
8+
9+
from areal.api.io_struct import ParamSpec, WeightUpdateMeta
10+
from areal.scheduler.rpc.serialization import serialize_value
11+
from areal.utils import logging
12+
13+
logger = logging.getLogger(__name__)
14+
15+
# Lazy-initialized thread pool for async HTTP requests
16+
_executor: ThreadPoolExecutor | None = None
17+
_executor_lock = threading.Lock()
18+
19+
20+
def _get_executor() -> ThreadPoolExecutor:
21+
"""Get or create the shared thread pool executor."""
22+
global _executor
23+
if _executor is None:
24+
with _executor_lock:
25+
if _executor is None:
26+
_executor = ThreadPoolExecutor(
27+
max_workers=4, thread_name_prefix="rollout_callback"
28+
)
29+
# Register cleanup on process exit
30+
atexit.register(_shutdown_executor)
31+
return _executor
32+
33+
34+
def _shutdown_executor() -> None:
35+
"""Shutdown the shared thread pool executor if it exists.
36+
37+
Called via atexit at process exit, when no other threads should be
38+
accessing the executor.
39+
"""
40+
global _executor
41+
if _executor is not None:
42+
_executor.shutdown(wait=False)
43+
_executor = None
44+
45+
46+
@dataclass
47+
class RolloutCallback:
48+
"""Callback interface for train workers to coordinate with TrainController.
49+
50+
This class acts as a proxy that train engines use to trigger operations on
51+
the inference side via HTTP callbacks to the TrainController. The controller
52+
then forwards these to the RolloutController.
53+
54+
IMPORTANT: Methods that return Future must be non-blocking to avoid deadlocks.
55+
NCCL operations are collective - both train and inference sides must participate
56+
concurrently. If these methods blocked, the train side couldn't start its NCCL
57+
operations while waiting for the inference side, causing a deadlock.
58+
"""
59+
60+
controller_addr: str
61+
request_timeout: float = 600.0
62+
63+
def _post(self, endpoint: str, payload: dict[str, Any] | None = None) -> dict:
64+
"""Make synchronous HTTP POST to controller callback endpoint.
65+
66+
Parameters
67+
----------
68+
endpoint : str
69+
The callback endpoint (e.g., "/callback/init_weights_group")
70+
payload : dict, optional
71+
JSON payload to send
72+
73+
Returns
74+
-------
75+
dict
76+
Response JSON from controller
77+
"""
78+
url = f"http://{self.controller_addr}{endpoint}"
79+
try:
80+
resp = requests.post(
81+
url,
82+
json=payload or {},
83+
timeout=self.request_timeout,
84+
)
85+
resp.raise_for_status()
86+
return resp.json()
87+
except requests.RequestException as e:
88+
logger.error(f"Callback to {url} failed: {e}")
89+
raise
90+
91+
def _post_nowait(
92+
self, endpoint: str, payload: dict[str, Any] | None = None
93+
) -> Future[dict]:
94+
"""Make asynchronous HTTP POST to controller callback endpoint.
95+
96+
This method submits the HTTP request to a background thread and returns
97+
immediately with a Future. This is critical for NCCL coordination where
98+
both train and inference sides must participate in collective operations
99+
concurrently.
100+
101+
Parameters
102+
----------
103+
endpoint : str
104+
The callback endpoint
105+
payload : dict, optional
106+
JSON payload to send
107+
108+
Returns
109+
-------
110+
Future[dict]
111+
Future that completes when the HTTP response is received
112+
"""
113+
return _get_executor().submit(self._post, endpoint, payload)
114+
115+
def _post_nowait_void(
116+
self, endpoint: str, payload: dict[str, Any] | None = None
117+
) -> Future[None]:
118+
"""Make an async POST request and return a Future that resolves to None."""
119+
http_future = self._post_nowait(endpoint, payload)
120+
result_future: Future[None] = Future()
121+
122+
def on_done(f: Future[dict]):
123+
try:
124+
f.result() # Raise any exception from the HTTP request
125+
result_future.set_result(None)
126+
except Exception as e:
127+
result_future.set_exception(e)
128+
129+
http_future.add_done_callback(on_done)
130+
return result_future
131+
132+
def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]:
133+
"""Callback to controller to initialize weight update group on inference side.
134+
135+
This method is NON-BLOCKING. It starts the HTTP request in a background
136+
thread and returns immediately. This allows the train engine to proceed
137+
with creating its side of the NCCL group concurrently.
138+
139+
Parameters
140+
----------
141+
meta : WeightUpdateMeta
142+
Weight update metadata
143+
144+
Returns
145+
-------
146+
Future[None]
147+
Future that completes when controller finishes initialization
148+
"""
149+
payload = {"meta": serialize_value(meta)}
150+
return self._post_nowait_void("/callback/init_weights_group", payload)
151+
152+
def update_weights_from_distributed(
153+
self, meta: WeightUpdateMeta, param_specs: list[ParamSpec]
154+
) -> Future[None]:
155+
"""Callback to controller to receive weights on inference side.
156+
157+
This method is NON-BLOCKING. The train engine calls this to notify the
158+
inference side to start receiving NCCL broadcasts, then immediately
159+
starts broadcasting. Both sides participate in the NCCL collective
160+
concurrently.
161+
162+
Parameters
163+
----------
164+
meta : WeightUpdateMeta
165+
Weight update metadata
166+
param_specs : list[ParamSpec]
167+
List of parameter specifications for this update batch
168+
169+
Returns
170+
-------
171+
Future[None]
172+
Future that completes when controller finishes receiving weights
173+
"""
174+
payload = {
175+
"meta": serialize_value(meta),
176+
"param_specs": serialize_value(param_specs),
177+
}
178+
return self._post_nowait_void("/callback/update_weights_xccl", payload)
179+
180+
def update_weights_from_disk(self, meta: WeightUpdateMeta) -> Future[None]:
181+
"""Callback to controller to load weights from disk on inference side.
182+
183+
This method is NON-BLOCKING for consistency, though disk-based updates
184+
don't have the same NCCL coordination requirements.
185+
186+
Parameters
187+
----------
188+
meta : WeightUpdateMeta
189+
Weight update metadata with path information
190+
191+
Returns
192+
-------
193+
Future[None]
194+
Future that completes when controller finishes loading weights
195+
"""
196+
payload = {"meta": serialize_value(meta)}
197+
return self._post_nowait_void("/callback/update_weights_disk", payload)
198+
199+
def pause_generation(self) -> None:
200+
"""Callback to controller to pause inference generation.
201+
202+
This is synchronous as it must complete before weight updates begin.
203+
"""
204+
self._post("/callback/pause_generation")
205+
206+
def continue_generation(self) -> None:
207+
"""Callback to controller to resume inference generation.
208+
209+
This is synchronous as it should complete before returning control.
210+
"""
211+
self._post("/callback/continue_generation")
