Skip to content

Commit 5e1f1be

Browse files
committed
use subprocess to fork colocated workers and centralize the usage of ThreadPoolExecutor
1 parent 9daf105 commit 5e1f1be

23 files changed

Lines changed: 1867 additions & 273 deletions

areal/api/cli_args.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,14 @@ class SchedulingStrategy:
630630
target: str | None = field(
631631
default=None, metadata={"help": "The target role to be colocated with"}
632632
)
633+
fork: bool = field(
634+
default=True,
635+
metadata={
636+
"help": "When True with colocation, the target worker spawns a new "
637+
"process on the same node/GPUs instead of sharing its process. "
638+
"Provides process isolation while sharing GPU resources."
639+
},
640+
)
633641

634642

635643
@dataclass

areal/api/scheduler_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class Job:
3636
replicas: int = 0
3737
tasks: list[SchedulingSpec] = field(default_factory=list)
3838
scheduling_strategy: SchedulingStrategy = field(default_factory=SchedulingStrategy)
39-
shared_placement_group: bool = False
4039

4140

4241
class Scheduler(abc.ABC):

areal/controller/rollout_callback.py

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import atexit
2-
import threading
3-
from concurrent.futures import Future, ThreadPoolExecutor
1+
from concurrent.futures import Future
42
from dataclasses import dataclass
53
from typing import Any
64

@@ -9,39 +7,10 @@
97
from areal.api.io_struct import ParamSpec, WeightUpdateMeta
108
from areal.scheduler.rpc.serialization import serialize_value
119
from areal.utils import logging
10+
from areal.utils.concurrent import get_executor
1211

1312
logger = logging.getLogger(__name__)
1413

15-
# Lazy-initialized thread pool for async HTTP requests
16-
_executor: ThreadPoolExecutor | None = None
17-
_executor_lock = threading.Lock()
18-
19-
20-
def _get_executor() -> ThreadPoolExecutor:
21-
"""Get or create the shared thread pool executor."""
22-
global _executor
23-
if _executor is None:
24-
with _executor_lock:
25-
if _executor is None:
26-
_executor = ThreadPoolExecutor(
27-
max_workers=4, thread_name_prefix="rollout_callback"
28-
)
29-
# Register cleanup on process exit
30-
atexit.register(_shutdown_executor)
31-
return _executor
32-
33-
34-
def _shutdown_executor() -> None:
35-
"""Shutdown the shared thread pool executor if it exists.
36-
37-
Called via atexit at process exit, when no other threads should be
38-
accessing the executor.
39-
"""
40-
global _executor
41-
if _executor is not None:
42-
_executor.shutdown(wait=False)
43-
_executor = None
44-
4514

4615
@dataclass
4716
class RolloutCallback:
@@ -110,7 +79,7 @@ def _post_nowait(
11079
Future[dict]
11180
Future that completes when the HTTP response is received
11281
"""
113-
return _get_executor().submit(self._post, endpoint, payload)
82+
return get_executor().submit(self._post, endpoint, payload)
11483

11584
def _post_nowait_void(
11685
self, endpoint: str, payload: dict[str, Any] | None = None

areal/controller/rollout_controller.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from areal.core.workflow_executor import BatchTaskDispatcher, TaskIdGenerator
3030
from areal.scheduler.rpc.serialization import deserialize_value
3131
from areal.utils import logging, perf_tracer
32+
from areal.utils.concurrent import run_async_task
3233
from areal.utils.data import concat_padded_tensors, cycle_dataloader
3334
from areal.utils.dynamic_import import import_from_string
3435
from areal.utils.network import find_free_ports, gethostip
@@ -139,18 +140,11 @@ def initialize(
139140
tasks=[sch_spec for _ in range(alloc_mode.gen.dp_size)],
140141
scheduling_strategy=self.config.scheduling_strategy,
141142
role=self._worker_role,
142-
shared_placement_group=False,
143143
)
144144

145-
# Use asyncio.run to call async scheduler methods synchronously
146-
asyncio.run(
147-
self._async_initialize(
148-
job,
149-
server_args,
150-
server_infos,
151-
*args,
152-
**kwargs,
153-
)
145+
# Call async scheduler methods synchronously
146+
run_async_task(
147+
self._async_initialize, job, server_args, server_infos, *args, **kwargs
154148
)
155149

156150
# Initialize staleness manager for global capacity control
@@ -385,7 +379,7 @@ def _resolve_task_future(self, task_id: int):
385379
future.get_loop().call_soon_threadsafe(future.set_result, None)
386380

387381
def _collective_rpc(self, method: str, *args, **kwargs) -> list[Any]:
388-
return asyncio.run(self._collective_rpc_async(method, *args, **kwargs))
382+
return run_async_task(self._collective_rpc_async, method, *args, **kwargs)
389383

390384
async def _collective_rpc_async(self, method: str, *args, **kwargs) -> list[Any]:
391385
tasks = [
@@ -764,7 +758,7 @@ async def _call():
764758
]
765759
return await asyncio.gather(*tasks)
766760

767-
asyncio.run(_call())
761+
run_async_task(_call)
768762

769763
def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
770764
self._collective_rpc("save_perf_tracer", step=step, force=force)

areal/controller/train_controller.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from areal.controller.rollout_controller import RolloutController
2020
from areal.scheduler.rpc.rtensor import RTensor
2121
from areal.utils import logging, stats_tracker
22+
from areal.utils.concurrent import run_async_task
2223
from areal.utils.network import find_free_ports
2324

2425
logger = logging.getLogger("TrainController")
@@ -132,7 +133,6 @@ def initialize(
132133
tasks=list(self.config.scheduling_spec),
133134
scheduling_strategy=self.config.scheduling_strategy,
134135
role=self._worker_role,
135-
shared_placement_group=True,
136136
)
137137

138138
# Create workers via scheduler
@@ -164,21 +164,16 @@ def initialize(
164164
engine_class = self.train_engine
165165

166166
# Create and initialize engines on workers
167-
self._run_async_task(
168-
self._async_create_engines(
169-
f"{engine_class.__module__}.{engine_class.__name__}"
170-
)
167+
run_async_task(
168+
self._async_create_engines,
169+
f"{engine_class.__module__}.{engine_class.__name__}",
171170
)
172-
self._run_async_task(self._async_initialize_engines(ft_spec, **kwargs))
171+
run_async_task(self._async_initialize_engines, ft_spec, **kwargs)
173172

174173
# Identify DP head workers
175174
self._identify_dp_heads()
176175
logger.info("TrainController initialization complete")
177176

178-
def _run_async_task(self, task):
179-
"""Run an async task synchronously."""
180-
return asyncio.run(task)
181-
182177
def _engine_name(self, rank: int) -> str:
183178
"""Generate engine name for a worker rank.
184179
@@ -255,7 +250,7 @@ async def _get_dp_head():
255250
]
256251
return await asyncio.gather(*tasks)
257252

258-
self.workers_is_dp_head = self._run_async_task(_get_dp_head())
253+
self.workers_is_dp_head = run_async_task(_get_dp_head)
259254

260255
def destroy(self):
261256
"""Destroy the controller and release GPU memory of models.
@@ -280,7 +275,7 @@ async def _destroy_all_engines():
280275
]
281276
await asyncio.gather(*tasks, return_exceptions=True)
282277

283-
self._run_async_task(_destroy_all_engines())
278+
run_async_task(_destroy_all_engines)
284279
logger.info("Engines destroyed")
285280
except Exception as e:
286281
logger.error(f"Error destroying engines: {e}")
@@ -306,8 +301,8 @@ def _custom_function_call(self, method: str, *args, **kwargs):
306301
dp_split_args, dp_split_kwargs, group_indices = self._dispatch_inputs(
307302
*args, **kwargs
308303
)
309-
results = self._run_async_task(
310-
self._call_with_dispatched_inputs(method, dp_split_args, dp_split_kwargs)
304+
results = run_async_task(
305+
self._call_with_dispatched_inputs, method, dp_split_args, dp_split_kwargs
311306
)
312307
# Filter to only keep results from DP head workers
313308
results = [r for idx, r in enumerate(results) if self.workers_is_dp_head[idx]]
@@ -527,7 +522,7 @@ async def _call():
527522
]
528523
return await asyncio.gather(*tasks)
529524

530-
self._run_async_task(_call())
525+
run_async_task(_call)
531526

532527
def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
533528
self._custom_function_call("save_perf_tracer", step=step, force=force)
@@ -588,4 +583,4 @@ async def _async_clear_batches(self, *targets: dict[str, RTensor]):
588583

589584
def clear_batches(self, *targets: dict[str, RTensor]):
590585
"""Clear distributed batch shards from workers to free memory."""
591-
self._run_async_task(self._async_clear_batches(*targets))
586+
run_async_task(self._async_clear_batches, *targets)

areal/core/workflow_executor.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import threading
88
import time
99
from collections.abc import Awaitable, Callable
10-
from concurrent.futures import ThreadPoolExecutor
1110
from dataclasses import dataclass
1211
from typing import TYPE_CHECKING, Any, TypeVar, Generic, Protocol
1312
from collections.abc import Generator
@@ -33,6 +32,7 @@
3332
from areal.core.workflow_context import WorkflowContext
3433
from areal.experimental.openai.types import InteractionWithTokenLogpReward
3534
from areal.utils import logging, perf_tracer, stats_tracker
35+
from areal.utils.concurrent import get_executor
3636
from areal.utils.data import concat_padded_tensors, cycle_dataloader
3737
from areal.utils.dynamic_import import import_from_string
3838
from areal.utils.perf_tracer import trace_perf, trace_session_event
@@ -361,8 +361,6 @@ def __init__(
361361

362362
# Callback support: task_id -> callback_addr
363363
self._task_callbacks: dict[int, str] = {}
364-
# Thread pool for sending callbacks (avoids creating threads per callback)
365-
self._callback_executor: ThreadPoolExecutor | None = None
366364

367365
def _set_thread_exception(self, exc: Exception):
368366
"""Store exception from background thread for fail-fast behavior."""
@@ -407,7 +405,7 @@ def post():
407405
except requests.RequestException as e:
408406
self.logger.error(f"Callback to {addr} failed: {e}")
409407

410-
self._callback_executor.submit(post)
408+
get_executor().submit(post)
411409

412410
def _commit_loop(self) -> None:
413411
"""Producer thread - continuously submits tasks based on capacity."""
@@ -504,9 +502,6 @@ def initialize(self, logger: Logger):
504502
self.runner.initialize(logger=logger)
505503

506504
self._shutdown_event.clear()
507-
self._callback_executor = ThreadPoolExecutor(
508-
max_workers=4, thread_name_prefix="callback"
509-
)
510505

511506
self._commit_thread = threading.Thread(target=self._commit_loop, daemon=True)
512507
self._commit_thread.start()
@@ -538,11 +533,6 @@ def destroy(self):
538533
# Clear pending callbacks to prevent memory leak
539534
self._task_callbacks.clear()
540535

541-
# Shutdown callback thread pool
542-
if self._callback_executor is not None:
543-
self._callback_executor.shutdown(wait=False)
544-
self._callback_executor = None
545-
546536
# Shutdown the async task runner
547537
self.runner.destroy()
548538

0 commit comments

Comments (0)