@@ -306,6 +306,26 @@ def main():
306306 args .global_rank )
307307 causal_lm_model_to_fp32_loss (model )
308308
309+ # Copied from ../step2_reward_model_finetuning/main.py.
310+ # Model bigscience/bloom-560m has large variance at ln_f.weight parameter
311+ # This makes bf16 finetuning hard.
312+ # In general, since we are replacing the model head, it makes sense to reset
313+ # the LN that precedes it.
314+ force_optimize_params = []
315+ if "bigscience/bloom-" in args .model_name_or_path :
316+ zero_init_enabled = (args .zero_stage == 3 )
317+ params = [
318+ model .rwtransformer .ln_f .weight , model .rwtransformer .ln_f .bias
319+ ]
320+ with deepspeed .zero .GatheredParameters (params ,
321+ modifier_rank = 0 ,
322+ enabled = zero_init_enabled ):
323+ if deepspeed .comm .get_rank () == 0 or not zero_init_enabled :
324+ torch .nn .init .ones_ (model .rwtransformer .ln_f .weight )
325+ torch .nn .init .zeros_ (model .rwtransformer .ln_f .bias )
326+ force_optimize_params .extend (
327+ ['rwtransformer.ln_f.weight' , 'rwtransformer.ln_f.bias' ])
328+
309329 if args .lora_dim > 0 :
310330 model = convert_linear_layer_to_lora (model , args .lora_module_name ,
311331 args .lora_dim )
@@ -372,7 +392,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
372392 logits = args .beta * ((chosen_logps - ref_chosen_logps ) -
373393 (rejected_logps - ref_rejected_logps ))
374394 loss = (- torch .nn .functional .logsigmoid (logits ) * (1 - args .label_smoothing ) - \
375- torch .nn .functional .logsigmoid (- logits ) * args .label_smoothing ).mean (0 )
395+ torch .nn .functional .logsigmoid (- logits ) * args .label_smoothing ).mean (0 )
376396 losses += loss .float ()
377397 losses = losses / (step + 1 )
378398 try :
@@ -419,7 +439,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
419439 # Train!
420440 print_rank_0 ("***** Running training *****" , args .global_rank )
421441 print_rank_0 (
422- f"***** Evaluating rewards, Epoch { 0 } /{ args .num_train_epochs } *****" ,
442+ f"***** Evaluating rewards, Epoch { 1 } /{ args .num_train_epochs } *****" ,
423443 args .global_rank )
424444 chosen_rewards , rejected_rewards , eval_loss = evaluation (
425445 model , ref_model , tokenizer , eval_dataloader )
@@ -466,7 +486,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
466486 logits = args .beta * ((chosen_logps - ref_chosen_logps ) -
467487 (rejected_logps - ref_rejected_logps ))
468488 loss = (- torch .nn .functional .logsigmoid (logits ) * (1 - args .label_smoothing ) - \
469- torch .nn .functional .logsigmoid (- logits ) * args .label_smoothing ).mean (0 )
489+ torch .nn .functional .logsigmoid (- logits ) * args .label_smoothing ).mean (0 )
470490 if args .print_loss :
471491 print (
472492 f"Epoch: { epoch } , Step: { step } , Rank: { torch .distributed .get_rank ()} , loss = { loss } "
@@ -478,7 +498,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
478498 print_throughput (model .model , args , end - start ,
479499 args .global_rank )
480500
481- # Evaluate perplexity on the validation set.
501+ # Evaluate rewards on the validation set.
482502 print_rank_0 (
483503 f"***** Evaluating rewards, Epoch { epoch + 1 } /{ args .num_train_epochs } *****" ,
484504 args .global_rank )
0 commit comments