applications/DeepSpeed-Chat/training/step1_supervised_finetuning

@@ -161,6 +161,13 @@ def parse_args():
     parser.add_argument('--only_optimize_lora',
                         action='store_true',
                         help='Only optimize the LoRA parameters.')
+    parser.add_argument(
+        "--lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial LoRA learning rate (after the potential warmup period) to use."
+    )
     ## Tensorboard logging
     parser.add_argument('--enable_tensorboard',
                         action='store_true',
@@ -274,7 +281,7 @@ def evaluation(model, eval_dataloader):

     # Split weights in two groups, one with weight decay and the other not.
     optimizer_grouped_parameters = get_optimizer_grouped_parameters(
-        model, args.weight_decay)
+        model, args.weight_decay, args.lora_learning_rate)

     AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
     optimizer = AdamOptimizer(optimizer_grouped_parameters,
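
The new third positional argument threads a dedicated LoRA learning rate into the parameter groups. A minimal sketch of what get_optimizer_grouped_parameters can do with it, assuming LoRA parameters are identifiable by name (the 'lora_right_weight'/'lora_left_weight' match below is an assumption modeled on DeepSpeed-Chat's LoRA module, and the whole body is an illustration rather than the repository's exact helper):

def get_optimizer_grouped_parameters(model, weight_decay, lora_lr=5e-4):
    # Assumed naming conventions; adjust for other LoRA implementations.
    no_decay_names = ["bias", "LayerNorm.weight"]
    lora_names = ["lora_right_weight", "lora_left_weight"]
    params = list(model.named_parameters())
    return [
        {
            # Regular weights: decayed, trained at the optimizer-level lr.
            "params": [p for n, p in params if p.requires_grad
                       and not any(nd in n for nd in no_decay_names)
                       and not any(ln in n for ln in lora_names)],
            "weight_decay": weight_decay,
        },
        {
            # LoRA weights: decayed, but with their own learning rate.
            "params": [p for n, p in params if p.requires_grad
                       and any(ln in n for ln in lora_names)],
            "weight_decay": weight_decay,
            "lr": lora_lr,
        },
        {
            # Biases and LayerNorm weights: never decayed.
            "params": [p for n, p in params if p.requires_grad
                       and any(nd in n for nd in no_decay_names)],
            "weight_decay": 0.0,
        },
    ]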

applications/DeepSpeed-Chat/training/step2_reward_model_finetuning

@@ -161,6 +161,13 @@ def parse_args():
     parser.add_argument('--only_optimize_lora',
                         action='store_true',
                         help='Only optimize the LoRA parameters.')
+    parser.add_argument(
+        "--lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial LoRA learning rate (after the potential warmup period) to use."
+    )
     ## Tensorboard logging
     parser.add_argument('--enable_tensorboard',
                         action='store_true',
@@ -271,7 +278,7 @@ def evaluation_reward(model, eval_dataloader):

     # Split weights in two groups, one with weight decay and the other not.
     optimizer_grouped_parameters = get_optimizer_grouped_parameters(
-        rm_model, args.weight_decay)
+        rm_model, args.weight_decay, args.lora_learning_rate)

     AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
     optimizer = AdamOptimizer(optimizer_grouped_parameters,

@@ -286,6 +286,20 @@ def parse_args():
     parser.add_argument('--only_optimize_lora',
                         action='store_true',
                         help='Only optimize the LoRA parameters.')
+    parser.add_argument(
+        "--actor_lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial actor LoRA learning rate (after the potential warmup period) to use."
+    )
+    parser.add_argument(
+        "--critic_lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial critic LoRA learning rate (after the potential warmup period) to use."
+    )
     ## Make EMA as an optional feature
     parser.add_argument('--enable_ema',
                         action='store_true',

@@ -105,7 +105,8 @@ def _init_actor(self, actor_model_name_or_path):
         # Optimizer
         AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam
         optim_params = get_optimizer_grouped_parameters(
-            actor_model, self.args.actor_weight_decay)
+            actor_model, self.args.actor_weight_decay,
+            self.args.actor_lora_learning_rate)
         optim = AdamOptimizer(optim_params,
                               lr=self.args.actor_learning_rate,
                               betas=(0.9, 0.95))
@@ -231,9 +232,10 @@ def _init_critic(self, critic_model_name_or_path):

         # Optimizer
         AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam
-        optim_pararms = get_optimizer_grouped_parameters(
-            critic_model, self.args.critic_weight_decay)
-        optim = AdamOptimizer(optim_pararms,
+        optim_params = get_optimizer_grouped_parameters(
+            critic_model, self.args.critic_weight_decay,
+            self.args.critic_lora_learning_rate)
+        optim = AdamOptimizer(optim_params,
                               lr=self.args.critic_learning_rate,
                               betas=(0.9, 0.95))

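A group-level "lr" entry takes precedence over the lr= passed to the optimizer constructor, so actor_learning_rate and critic_learning_rate still govern the non-LoRA groups while the LoRA group runs at its own rate. A quick standalone PyTorch check of that behavior (toy module, not repository code):

import torch

net = torch.nn.Linear(4, 4)
groups = [
    {"params": [net.weight], "lr": 5e-4},  # like the LoRA group: its own lr
    {"params": [net.bias]},                # inherits the optimizer default
]
opt = torch.optim.AdamW(groups, lr=1e-5, betas=(0.9, 0.95))
print([g["lr"] for g in opt.param_groups])  # -> [0.0005, 1e-05]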