Skip to content

Commit c77e75a

Browse files
authored
Add Hybrid Engine mode + move arguments to a file (#684)
1 parent 5031a60 commit c77e75a

2 files changed

Lines changed: 31 additions & 20 deletions

File tree

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Command-line arguments shared by the HuggingFace text-generation
# inference scripts (imported as `from arguments import parser`).
from argparse import ArgumentParser
import os

parser = ArgumentParser()

# Model selection and checkpoint handling.
parser.add_argument("--model", required=True, type=str, help="model_name")
parser.add_argument("--checkpoint_path", required=False, default=None, type=str, help="model checkpoint path")
parser.add_argument("--save_mp_checkpoint_path", required=False, default=None, type=str, help="save-path to store the new model checkpoint")
# Generation / runtime configuration.
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type")
parser.add_argument("--hf_baseline", action='store_true', help="disable DeepSpeed inference")
parser.add_argument("--use_kernel", action='store_true', help="enable kernel-injection")
parser.add_argument("--max_tokens", default=1024, type=int, help="maximum tokens used for the text-generation KV-cache")
parser.add_argument("--max_new_tokens", default=50, type=int, help="maximum new tokens to generate")
parser.add_argument("--greedy", action='store_true', help="greedy generation mode")
parser.add_argument("--use_meta_tensor", action='store_true', help="use the meta tensors to initialize model")
# NOTE: fixed typo in help text ("throughout" -> "throughput").
parser.add_argument("--test_performance", action='store_true', help="enable latency, bandwidth, and throughput testing")
# Distributed-launch settings; defaults come from the launcher's environment.
parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank")
parser.add_argument("--world_size", type=int, default=int(os.getenv("WORLD_SIZE", "1")), help="world_size")
parser.add_argument("--test_hybrid_engine", action='store_true', help="enable hybrid engine testing")

inference/huggingface/text-generation/inference-test.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from argparse import ArgumentParser
21
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
32
import deepspeed
43
import math
@@ -7,23 +6,8 @@
76
import time
87
from utils import DSPipeline, Performance
98
from deepspeed.runtime.utils import see_memory_usage
9+
from arguments import parser
1010

11-
parser = ArgumentParser()
12-
13-
parser.add_argument("--model", required=True, type=str, help="model_name")
14-
parser.add_argument("--checkpoint_path", required=False, default=None, type=str, help="model checkpoint path")
15-
parser.add_argument("--save_mp_checkpoint_path", required=False, default=None, type=str, help="save-path to store the new model checkpoint")
16-
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
17-
parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type")
18-
parser.add_argument("--hf_baseline", action='store_true', help="disable DeepSpeed inference")
19-
parser.add_argument("--use_kernel", action='store_true', help="enable kernel-injection")
20-
parser.add_argument("--max_tokens", default=1024, type=int, help="maximum tokens used for the text-generation KV-cache")
21-
parser.add_argument("--max_new_tokens", default=50, type=int, help="maximum new tokens to generate")
22-
parser.add_argument("--greedy", action='store_true', help="greedy generation mode")
23-
parser.add_argument("--use_meta_tensor", action='store_true', help="use the meta tensors to initialize model")
24-
parser.add_argument("--test_performance", action='store_true', help="enable latency, bandwidth, and throughout testing")
25-
parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank")
26-
parser.add_argument("--world_size", type=int, default=int(os.getenv("WORLD_SIZE", "1")), help="world_size")
2711
args = parser.parse_args()
2812

2913
if args.hf_baseline and args.world_size > 1:
@@ -51,16 +35,23 @@
5135
else:
5236
ds_kwargs = dict()
5337

54-
# Use DeepSpeed Hybrid Engine for inference (training engine with
# inference-optimized kernels enabled via the "hybrid_engine" config).
if args.test_hybrid_engine:
    # fp16 is enabled iff the requested dtype resolved to torch.half;
    # simplified from `True if data_type==torch.half else False`.
    ds_config = {"train_batch_size": args.batch_size,
                 "fp16": {"enabled": data_type == torch.half},
                 "hybrid_engine": {"enabled": True}}
    # deepspeed.initialize returns (engine, optimizer, dataloader, scheduler);
    # only the engine is needed here.
    pipe.model, *_ = deepspeed.initialize(model=pipe.model, config=ds_config)
    pipe.model.eval()
# If not trying with the HuggingFace baseline, use DeepSpeed Inference Engine
else:
    if not args.hf_baseline:
        pipe.model = deepspeed.init_inference(pipe.model,
                                              dtype=data_type,
                                              mp_size=args.world_size,
                                              replace_with_kernel_inject=args.use_kernel,
                                              max_tokens=args.max_tokens,
                                              save_mp_checkpoint_path=args.save_mp_checkpoint_path,
                                              **ds_kwargs
                                              )
63-
54+
6455
if args.local_rank == 0:
6556
see_memory_usage("after init_inference", True)
6657

0 commit comments

Comments
 (0)