Skip to content

Commit 7574df0

Browse files
author
stceum
committed
Follow upstream fixes.
1 parent 1f424f1 commit 7574df0

1 file changed

Lines changed: 24 additions & 4 deletions

File tree

  • applications/DeepSpeed-Chat/training/step2_dpo_finetuning

applications/DeepSpeed-Chat/training/step2_dpo_finetuning/main.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,26 @@ def main():
306306
args.global_rank)
307307
causal_lm_model_to_fp32_loss(model)
308308

309+
# Copied from ../step2_reward_model_finetuning/main.py.
310+
# Model bigscience/bloom-560m has large variance at ln_f.weight parameter
311+
# This makes bf16 finetuning hard.
312+
# In general, since we are replacing the model head, it makes sense to reset
313+
# the LN that precedes it.
314+
force_optimize_params = []
315+
if "bigscience/bloom-" in args.model_name_or_path:
316+
zero_init_enabled = (args.zero_stage == 3)
317+
params = [
318+
model.rwtranrsformer.ln_f.weight, model.rwtranrsformer.ln_f.bias
319+
]
320+
with deepspeed.zero.GatheredParameters(params,
321+
modifier_rank=0,
322+
enabled=zero_init_enabled):
323+
if deepspeed.comm.get_rank() == 0 or not zero_init_enabled:
324+
torch.nn.init.ones_(model.rwtransformer.ln_f.weight)
325+
torch.nn.init.zeros_(model.rwtransformer.ln_f.bias)
326+
force_optimize_params.extend(
327+
['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias'])
328+
309329
if args.lora_dim > 0:
310330
model = convert_linear_layer_to_lora(model, args.lora_module_name,
311331
args.lora_dim)
@@ -372,7 +392,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
372392
logits = args.beta * ((chosen_logps - ref_chosen_logps) -
373393
(rejected_logps - ref_rejected_logps))
374394
loss = (- torch.nn.functional.logsigmoid(logits) * (1 - args.label_smoothing) - \
375-
torch.nn.functional.logsigmoid(-logits) * args.label_smoothing).mean(0)
395+
torch.nn.functional.logsigmoid(-logits) * args.label_smoothing).mean(0)
376396
losses += loss.float()
377397
losses = losses / (step + 1)
378398
try:
@@ -419,7 +439,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
419439
# Train!
420440
print_rank_0("***** Running training *****", args.global_rank)
421441
print_rank_0(
422-
f"***** Evaluating rewards, Epoch {0}/{args.num_train_epochs} *****",
442+
f"***** Evaluating rewards, Epoch {1}/{args.num_train_epochs} *****",
423443
args.global_rank)
424444
chosen_rewards, rejected_rewards, eval_loss = evaluation(
425445
model, ref_model, tokenizer, eval_dataloader)
@@ -466,7 +486,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
466486
logits = args.beta * ((chosen_logps - ref_chosen_logps) -
467487
(rejected_logps - ref_rejected_logps))
468488
loss = (- torch.nn.functional.logsigmoid(logits) * (1 - args.label_smoothing) - \
469-
torch.nn.functional.logsigmoid(-logits) * args.label_smoothing).mean(0)
489+
torch.nn.functional.logsigmoid(-logits) * args.label_smoothing).mean(0)
470490
if args.print_loss:
471491
print(
472492
f"Epoch: {epoch}, Step: {step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
@@ -478,7 +498,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
478498
print_throughput(model.model, args, end - start,
479499
args.global_rank)
480500

481-
# Evaluate perplexity on the validation set.
501+
# Evaluate rewards on the validation set.
482502
print_rank_0(
483503
f"***** Evaluating rewards, Epoch {epoch+1}/{args.num_train_epochs} *****",
484504
args.global_rank)

0 commit comments

Comments (0)