diff --git a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py index 9df8b5a0f..9db666913 100755 --- a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py @@ -29,6 +29,7 @@ from utils.ds_utils import get_train_ds_config from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible from utils.model.model_utils import create_hf_model +from utils.perf import print_throughput def parse_args(): @@ -321,7 +322,9 @@ def evaluation(model, eval_dataloader): f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}", args.global_rank) model.train() + import time for step, batch in enumerate(train_dataloader): + start = time.time() batch = to_device(batch, device) outputs = model(**batch, use_cache=False) loss = outputs.loss @@ -331,6 +334,10 @@ def evaluation(model, eval_dataloader): ) model.backward(loss) model.step() + end = time.time() + if torch.distributed.get_rank() == 0: + print_throughput(model.model, args, end - start, + args.global_rank) # Evaluate perplexity on the validation set. 
print_rank_0( diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py index 1743fba02..011b0e830 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py @@ -43,7 +43,7 @@ from utils.data.data_utils import create_prompt_dataset, MiniDataset, DataCollatorRLHF, get_unsupervised_data from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, moving_average, save_zero_three_model, load_hf_tokenizer from utils.module.lora import convert_lora_to_linear_layer -from utils.perf import print_throughput +from utils.perf import print_throughput_step3 writer = None @@ -481,7 +481,7 @@ def main(): for step, (batch_prompt, batch_unsupervised) in enumerate( zip(prompt_train_dataloader, unsupervised_train_dataloader)): - start = time.time() + # start = time.time() batch_prompt = to_device(batch_prompt, device) # prompts = batch_prompt['prompt'] @@ -535,15 +535,15 @@ def main(): random.shuffle(unsup_dataset) end = time.time() - e2e_time = end - start training_time = end - training_start + e2e_time = training_time + trainer.generate_time * args.generation_batches # it is an approximation, we did not include, e.g., rw forward time etc print_rank_0( f'Epoch: {epoch} | Step: {step} | PPO Epoch: {ppo_ep+1} | Actor Loss: {actor_loss_sum/inner_iter} | Critic Loss: {critic_loss_sum/inner_iter} | Unsupervised Loss: {unsup_loss_sum/inner_iter}', args.global_rank) - print_throughput(rlhf_engine.actor.model, args, e2e_time, - trainer.generate_time, training_time, - args.global_rank) + print_throughput_step3(rlhf_engine.actor.model, args, e2e_time, + trainer.generate_time, training_time, + args.global_rank) average_reward = get_all_reduce_mean(average_reward).item() print_rank_0( f"Average reward score: {average_reward/inner_iter}", diff --git 
# This function can be used to print throughput for Step 1 and 2 only
def print_throughput(hf_model, args, e2e_time, rank=0):
    """Print per-iteration training throughput for SFT / reward-model steps.

    Estimates TFLOPs with the Megatron-LM paper's analytical formula and
    prints latency, samples/sec, and model size. Only a caller passing
    ``rank <= 0`` (i.e. global rank 0) produces output; other ranks no-op.

    Args:
        hf_model: HuggingFace model (possibly ZeRO-3 partitioned). Its
            ``config`` must expose layer count, hidden size, and vocab size.
        args: parsed CLI namespace; reads ``max_seq_len``,
            ``per_device_train_batch_size``, and ``gradient_checkpointing``.
        e2e_time: wall-clock seconds for one end-to-end training iteration.
        rank: global rank of the caller; ranks > 0 return immediately.
    """
    if rank <= 0:
        hf_config = hf_model.config
        # Config attribute names differ across architectures (e.g. GPT-2
        # style configs use n_layer / n_embd), so fall back accordingly.
        num_layers = getattr(hf_config, "num_hidden_layers",
                             getattr(hf_config, "n_layer", None))
        hidden_size = getattr(hf_config, "hidden_size",
                              getattr(hf_config, "n_embd", None))
        vocab_size = getattr(hf_config, "vocab_size", None)
        assert all(
            (num_layers, hidden_size, vocab_size)
        ), "Could not determine number of layers, hidden size, and vocab size of the model"

        gpus_per_model = torch.distributed.get_world_size()
        seq_length = args.max_seq_len
        batch_size = args.per_device_train_batch_size
        samples_per_second = batch_size / e2e_time
        # Activation checkpointing re-runs the forward pass during backward:
        # 4 forward-equivalent passes per iteration instead of 3.
        checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3
        # Under ZeRO-3 each parameter is sharded; ds_numel holds the full
        # (unpartitioned) element count while p.numel() would only see the
        # local shard. Generator avoids building a throwaway list.
        hf_model._num_params = sum(
            p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
            for p in hf_model.parameters())
        params_in_billions = hf_model._num_params / (1e9)

        # Megatron paper's formula to calculate training flops
        train_flops_per_iteration = (
            24 * checkpoint_activations_factor * batch_size * seq_length *
            num_layers *
            (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
                                 (vocab_size /
                                  (16.0 * num_layers * hidden_size)))

        train_tflops = train_flops_per_iteration / (e2e_time * gpus_per_model *
                                                    (10**12))

        param_string = f"{params_in_billions:.3f} B" if params_in_billions != 0 else "NA"
        print(
            f"Model Parameters: {param_string}, Latency: {e2e_time:.2f}s, TFLOPs: {train_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}"
        )
TFLOPs: {gen_tflops:.2f}, BW: {gen_bw:.2f} GB/sec" + f"Generation => Latency: {gen_exp_time:.2f}s, TFLOPs: {gen_tflops:.2f}, BW: {gen_bw:.2f} GB/sec" ) print( f"Training => Latency: {train_time:.2f}s, TFLOPs: {train_tflops:.2f}"