1919from areal .controller .rollout_controller import RolloutController
2020from areal .scheduler .rpc .rtensor import RTensor
2121from areal .utils import logging , stats_tracker
22+ from areal .utils .concurrent import run_async_task
2223from areal .utils .network import find_free_ports
2324
2425logger = logging .getLogger ("TrainController" )
@@ -132,7 +133,6 @@ def initialize(
132133 tasks = list (self .config .scheduling_spec ),
133134 scheduling_strategy = self .config .scheduling_strategy ,
134135 role = self ._worker_role ,
135- shared_placement_group = True ,
136136 )
137137
138138 # Create workers via scheduler
@@ -164,21 +164,16 @@ def initialize(
164164 engine_class = self .train_engine
165165
166166 # Create and initialize engines on workers
167- self ._run_async_task (
168- self ._async_create_engines (
169- f"{ engine_class .__module__ } .{ engine_class .__name__ } "
170- )
167+ run_async_task (
168+ self ._async_create_engines ,
169+ f"{ engine_class .__module__ } .{ engine_class .__name__ } " ,
171170 )
172- self . _run_async_task (self ._async_initialize_engines ( ft_spec , ** kwargs ) )
171+ run_async_task (self ._async_initialize_engines , ft_spec , ** kwargs )
173172
174173 # Identify DP head workers
175174 self ._identify_dp_heads ()
176175 logger .info ("TrainController initialization complete" )
177176
178- def _run_async_task (self , task ):
179- """Run an async task synchronously."""
180- return asyncio .run (task )
181-
182177 def _engine_name (self , rank : int ) -> str :
183178 """Generate engine name for a worker rank.
184179
@@ -255,7 +250,7 @@ async def _get_dp_head():
255250 ]
256251 return await asyncio .gather (* tasks )
257252
258- self .workers_is_dp_head = self . _run_async_task (_get_dp_head () )
253+ self .workers_is_dp_head = run_async_task (_get_dp_head )
259254
260255 def destroy (self ):
261256 """Destroy the controller and release GPU memory of models.
@@ -280,7 +275,7 @@ async def _destroy_all_engines():
280275 ]
281276 await asyncio .gather (* tasks , return_exceptions = True )
282277
283- self . _run_async_task (_destroy_all_engines () )
278+ run_async_task (_destroy_all_engines )
284279 logger .info ("Engines destroyed" )
285280 except Exception as e :
286281 logger .error (f"Error destroying engines: { e } " )
@@ -306,8 +301,8 @@ def _custom_function_call(self, method: str, *args, **kwargs):
306301 dp_split_args , dp_split_kwargs , group_indices = self ._dispatch_inputs (
307302 * args , ** kwargs
308303 )
309- results = self . _run_async_task (
310- self ._call_with_dispatched_inputs ( method , dp_split_args , dp_split_kwargs )
304+ results = run_async_task (
305+ self ._call_with_dispatched_inputs , method , dp_split_args , dp_split_kwargs
311306 )
312307 # Filter to only keep results from DP head workers
313308 results = [r for idx , r in enumerate (results ) if self .workers_is_dp_head [idx ]]
@@ -527,7 +522,7 @@ async def _call():
527522 ]
528523 return await asyncio .gather (* tasks )
529524
530- self . _run_async_task (_call () )
525+ run_async_task (_call )
531526
532527 def save_perf_tracer (self , step : int | None = None , force : bool = False ) -> None :
533528 self ._custom_function_call ("save_perf_tracer" , step = step , force = force )
@@ -588,4 +583,4 @@ async def _async_clear_batches(self, *targets: dict[str, RTensor]):
588583
589584 def clear_batches (self , * targets : dict [str , RTensor ]):
590585 """Clear distributed batch shards from workers to free memory."""
591- self . _run_async_task (self ._async_clear_batches ( * targets ) )
586+ run_async_task (self ._async_clear_batches , * targets )
0 commit comments