
Commit 7c9e470

Single-controller LoRA RL fine-tuning with vLLM (#735)
* Working and tested example for GRPO single-controller LoRA using the vLLM backend
* Cleaned up some debug statements
* Updated and tested (performance matches full RL) per the new design
* Removed the old single-controller examples in the lora folder, as they are no longer required
1 parent 9497437 commit 7c9e470

3 files changed: 212 additions & 9 deletions

File tree

* areal/api/io_struct.py
* areal/experimental/trainer/rl.py
* examples/math/gsm8k_grpo_lora.yaml

areal/api/io_struct.py

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ def from_disk(
         use_lora: bool = False,
         clear_checkpoint_after_load: bool = True,
         lora_name: str = "",
-        lora_int_id: int = 0,
+        lora_int_id: int = 1,
         base_model_name: str = "",
     ) -> "WeightUpdateMeta":
         from areal.utils.saver import Saver
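Context for the one-line change above: vLLM assigns adapter ID 0 to the base model and rejects LoRARequest objects whose lora_int_id is not a positive integer, so a default of 0 would fail as soon as the trainer registers updated adapter weights with the vLLM backend. A minimal sketch of the consuming side, with a placeholder adapter path:

from vllm.lora.request import LoRARequest

# Sketch only: the adapter path is a placeholder, not part of this commit.
# vLLM validates lora_int_id > 0, which is why the default moved to 1.
req = LoRARequest(
    lora_name="lora-gsm8k",        # matches gconfig.lora_name in the example config
    lora_int_id=1,                 # any positive integer, unique per adapter
    lora_path="/path/to/adapter",  # checkpoint directory written at weight-update time
)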

areal/experimental/trainer/rl.py

Lines changed: 19 additions & 8 deletions
@@ -132,14 +132,25 @@ def __init__(
 
         # Prepare weight update meta and connect to inference engine
         if self.config.actor.weight_update_mode == "disk":
-            self.weight_update_meta = WeightUpdateMeta.from_disk(
-                experiment_name=config.experiment_name,
-                trial_name=config.trial_name,
-                file_root=config.cluster.fileroot,
-                name="default",
-                use_lora=config.actor.use_lora,
-                clear_checkpoint_after_load=True,
-            )
+            if config.actor.use_lora:
+                self.weight_update_meta = WeightUpdateMeta.from_disk(
+                    experiment_name=config.experiment_name,
+                    trial_name=config.trial_name,
+                    file_root=config.cluster.fileroot,
+                    name="default",
+                    clear_checkpoint_after_load=True,
+                    use_lora=config.actor.use_lora,
+                    lora_name=config.gconfig.lora_name,
+                    base_model_name=config.actor.path,
+                )
+            else:
+                self.weight_update_meta = WeightUpdateMeta.from_disk(
+                    experiment_name=config.experiment_name,
+                    trial_name=config.trial_name,
+                    file_root=config.cluster.fileroot,
+                    name="default",
+                    clear_checkpoint_after_load=True,
+                )
         elif self.config.actor.weight_update_mode == "xccl":
             # NCCL/XCCL weight update
             if self.allocation_mode.train_backend == "megatron":
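The two branches differ only in the three LoRA-specific keyword arguments (use_lora, lora_name, base_model_name). A hypothetical equivalent that folds them into a single call, sketched against the same names used in the diff (not part of the committed code):

# Sketch only: same behavior as the two committed branches, written once.
lora_kwargs = (
    dict(
        use_lora=True,  # config.actor.use_lora is True on this branch
        lora_name=config.gconfig.lora_name,
        base_model_name=config.actor.path,
    )
    if config.actor.use_lora
    else {}
)
self.weight_update_meta = WeightUpdateMeta.from_disk(
    experiment_name=config.experiment_name,
    trial_name=config.trial_name,
    file_root=config.cluster.fileroot,
    name="default",
    clear_checkpoint_after_load=True,
    **lora_kwargs,
)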

examples/math/gsm8k_grpo_lora.yaml

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
experiment_name: gsm8k-grpo
trial_name: trial0

seed: 1
enable_offload: false
total_train_epochs: 3
tokenizer_path: ${actor.path}

cluster:
  n_nodes: 1
  n_gpus_per_node: 16
  fileroot: /tmp/areal/experiments
  name_resolve:
    type: nfs
    nfs_record_root: /tmp/areal/name_resolve

allocation_mode: vllm:d8p1t1+d8p1t1

scheduler:
  type: local

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: 256
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 2
  enable_rollout_tracing: false
  use_lora: true
  scheduling_spec: ${actor.scheduling_spec}

gconfig:
  n_samples: 4
  min_new_tokens: 0
  max_new_tokens: 1024
  greedy: false
  temperature: 1.0
  lora_name: "lora-gsm8k"

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: Qwen/Qwen3-0.6B
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 10240
  optimizer:
    type: adam
    lr: 1.70e-4
    weight_decay: 0.017
    beta1: 0.9
    beta2: 0.999
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  group_size: ${gconfig.n_samples}
  eps_clip: 0.4
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 1
  recompute_logprob: true
  use_decoupled_loss: true
  behav_imp_weight_cap: 5.0
  dynamic_sampling: false
  reward_norm:
    mean_level: group
    std_level: group
    group_size: ${gconfig.n_samples}
  adv_norm:
    mean_level: batch
    std_level: batch
  max_new_tokens: ${gconfig.max_new_tokens}
  weight_update_mode: disk
  use_lora: ${rollout.use_lora}
  peft_type: lora
  lora_rank: 16
  lora_alpha: 16
  target_modules: [all-linear]
  scheduling_spec:
    - task_type: worker
      port_count: 2
      gpu: 1
      cpu: 4
      mem: 32
      cmd: python3 -m areal.scheduler.rpc.rpc_server
      env_vars: {}

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  disable_dropout: true
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 10240
  optimizer: null
  scheduling_strategy:
    type: colocation
    target: actor
  scheduling_spec: ${actor.scheduling_spec}

# SGLang
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.8

# vLLM
vllm:
  model: ${actor.path}
  seed: ${seed}
  skip_tokenizer_init: false
  dtype: ${actor.dtype}
  max_model_len: 32768
  gpu_memory_utilization: 0.8
  enable_lora: ${rollout.use_lora}
  lora_modules: '{"name": "${gconfig.lora_name}", "path": "./model/Qwen3.0.6B-16rank", "base_model_name": "${actor.path}"}'
  enforce_eager: true

# datasets
train_dataset:
  batch_size: 256
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: openai/gsm8k
  type: rl
  max_length: 1024

valid_dataset:
  batch_size: 256
  pin_memory: true
  num_workers: 4
  path: openai/gsm8k
  type: rl

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: null

recover:
  mode: disabled
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: 1
  freq_steps: null
  freq_secs: 3600

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: null
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  wandb:
    mode: disabled

perf_tracer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  enabled: false
  session_tracer:
    enabled: false
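A note on vllm.lora_modules above: the value is a JSON object embedded in a YAML string (the name/path/base_model_name form that recent vLLM versions accept for pre-registering LoRA adapters), so after OmegaConf resolves the ${...} interpolations it must survive json.loads, quotes included. A quick sanity check, sketched with the resolved values from this config:

import json

# Mirrors vllm.lora_modules after ${...} interpolation; values taken from this config.
spec = json.loads(
    '{"name": "lora-gsm8k", '
    '"path": "./model/Qwen3.0.6B-16rank", '
    '"base_model_name": "Qwen/Qwen3-0.6B"}'
)
assert {"name", "path", "base_model_name"} <= spec.keys()
print(f"adapter {spec['name']} -> {spec['path']} (base {spec['base_model_name']})")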
