Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from utils.ds_utils import get_train_ds_config
from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
from utils.model.model_utils import create_hf_model
from utils.perf import print_throughput


def parse_args():
Expand Down Expand Up @@ -321,7 +322,9 @@ def evaluation(model, eval_dataloader):
f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}",
args.global_rank)
model.train()
import time
for step, batch in enumerate(train_dataloader):
start = time.time()
batch = to_device(batch, device)
outputs = model(**batch, use_cache=False)
loss = outputs.loss
Expand All @@ -331,6 +334,10 @@ def evaluation(model, eval_dataloader):
)
model.backward(loss)
model.step()
end = time.time()
if torch.distributed.get_rank() == 0:
print_throughput(model.model, args, end - start,
args.global_rank)

# Evaluate perplexity on the validation set.
print_rank_0(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from utils.data.data_utils import create_prompt_dataset, MiniDataset, DataCollatorRLHF, get_unsupervised_data
from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, moving_average, save_zero_three_model, load_hf_tokenizer
from utils.module.lora import convert_lora_to_linear_layer
from utils.perf import print_throughput
from utils.perf import print_throughput_step3

writer = None

Expand Down Expand Up @@ -481,7 +481,7 @@ def main():
for step, (batch_prompt, batch_unsupervised) in enumerate(
zip(prompt_train_dataloader, unsupervised_train_dataloader)):

start = time.time()
# start = time.time()
batch_prompt = to_device(batch_prompt, device)

# prompts = batch_prompt['prompt']
Expand Down Expand Up @@ -535,15 +535,15 @@ def main():
random.shuffle(unsup_dataset)

end = time.time()
e2e_time = end - start
training_time = end - training_start
e2e_time = training_time + trainer.generate_time * args.generation_batches  # an approximation: it does not include, e.g., the reward model forward time

print_rank_0(
f'Epoch: {epoch} | Step: {step} | PPO Epoch: {ppo_ep+1} | Actor Loss: {actor_loss_sum/inner_iter} | Critic Loss: {critic_loss_sum/inner_iter} | Unsupervised Loss: {unsup_loss_sum/inner_iter}',
args.global_rank)
print_throughput(rlhf_engine.actor.model, args, e2e_time,
trainer.generate_time, training_time,
args.global_rank)
print_throughput_step3(rlhf_engine.actor.model, args, e2e_time,
trainer.generate_time, training_time,
args.global_rank)
average_reward = get_all_reduce_mean(average_reward).item()
print_rank_0(
f"Average reward score: {average_reward/inner_iter}",
Expand Down
67 changes: 55 additions & 12 deletions applications/DeepSpeed-Chat/training/utils/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,54 @@
import torch


def print_throughput(hf_model,
args,
e2e_time,
gen_exp_time,
train_time,
rank=0):
# This function can be used to print throughput for Step 1 and 2 only
def print_throughput(hf_model, args, e2e_time, rank=0):
    """Print a one-line training-throughput summary for Steps 1 and 2.

    TFLOPs are estimated with the Megatron-LM paper's per-iteration FLOPs
    formula. Only rank <= 0 prints; all other ranks return immediately.
    Side effect: caches the parameter count on ``hf_model._num_params``.
    """
    if rank > 0:
        return

    cfg = hf_model.config
    # Config attribute names differ across model families (e.g. GPT-2 uses
    # n_layer / n_embd), so fall back accordingly.
    num_layers = getattr(cfg, "num_hidden_layers", getattr(cfg, "n_layer", None))
    hidden_size = getattr(cfg, "hidden_size", getattr(cfg, "n_embd", None))
    vocab_size = getattr(cfg, "vocab_size", None)
    assert all(
        (num_layers, hidden_size, vocab_size)
    ), "Could not determine number of layers, hidden size, and vocab size of the model"

    gpus_per_model = torch.distributed.get_world_size()
    seq_length = args.max_seq_len
    batch_size = args.per_device_train_batch_size
    samples_per_second = batch_size / e2e_time
    # 3 = forward + backward cost ratio; gradient checkpointing recomputes
    # the forward pass, raising the factor to 4 (Megatron FLOPs accounting).
    checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3
    # DeepSpeed-partitioned parameters (those carrying ds_tensor) report their
    # full size via ds_numel rather than numel().
    hf_model._num_params = sum(
        p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
        for p in hf_model.parameters())
    params_in_billions = hf_model._num_params / (1e9)

    # Megatron paper's formula to calculate training flops
    train_flops_per_iteration = (
        24 * checkpoint_activations_factor * batch_size * seq_length *
        num_layers * (hidden_size**2)) * (
            1.0 + (seq_length / (6.0 * hidden_size)) +
            (vocab_size / (16.0 * num_layers * hidden_size)))

    train_tflops = train_flops_per_iteration / (
        e2e_time * gpus_per_model * (10**12))

    param_string = f"{params_in_billions:.3f} B" if params_in_billions != 0 else "NA"
    print(
        f"Model Parameters: {param_string}, Latency: {e2e_time:.2f}s, TFLOPs: {train_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}"
    )


# Enhanced version of the function above that provides calculations and printing for Step 3
def print_throughput_step3(hf_model,
args,
e2e_time,
gen_exp_time,
train_time,
rank=0):
if rank <= 0:
hf_config = hf_model.config
num_layers = getattr(hf_config, "num_hidden_layers",
Expand All @@ -34,8 +76,7 @@ def print_throughput(hf_model,
])
params_in_billions = hf_model._num_params / (1e9)

# megatron paper formula

# Megatron paper's formula to calculate training flops
train_flops_per_iteration = (
24 * checkpoint_activations_factor * batch_size * seq_length *
num_layers *
Expand All @@ -48,9 +89,9 @@ def print_throughput(hf_model,

gen_bs = args.per_device_generation_batch_size * gpus_per_model

# Modified formula for calculating flops in forward pass only
gen_flops_per_iteration = (
24 * checkpoint_activations_factor * gen_bs * seq_length *
num_layers *
24 * gen_bs * seq_length * num_layers *
(hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
(vocab_size /
(16.0 * num_layers * hidden_size)))
Expand All @@ -68,13 +109,15 @@ def print_throughput(hf_model,
gen_bw = (hf_model._num_params *
(num_bytes / 1e9)) / gen_exp_time * args.max_answer_seq_len

total_tflops = train_tflops + gen_tflops * args.generation_batches
total_flops_per_iteration = train_flops_per_iteration + gen_flops_per_iteration * args.generation_batches
total_tflops = total_flops_per_iteration / (e2e_time * gpus_per_model *
(10**12))

print(
f"End-to-End => Latency: {e2e_time:.2f}s, TFLOPs: {total_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}"
)
print(
f"Generation => Latency: {gen_exp_time:.2f}s, Approx. TFLOPs: {gen_tflops:.2f}, BW: {gen_bw:.2f} GB/sec"
f"Generation => Latency: {gen_exp_time:.2f}s, TFLOPs: {gen_tflops:.2f}, BW: {gen_bw:.2f} GB/sec"
)
print(
f"Training => Latency: {train_time:.2f}s, TFLOPs: {train_tflops:.2f}"
Expand Down