Skip to content

Commit eea58e9

Browse files
delocktjruwase
authored and committed
Enable non-CUDA device for CIFAR10 and HelloDeepSpeed training example (deepspeedai#651)
* support bf16 and CPU accelerator * support both bfloat16 and fp16 data type * change default data type to bf16 to help run this demo on both CPU and GPU * enable HelloDeepSpeed for non-CUDA device * revert changes for sh output * allow select bf16/fp16 datatype * revert unnecessary changes * separate bf16 and fp16 config --------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
1 parent 6f7f121 commit eea58e9

File tree

9 files changed

+110
-73
lines changed

9 files changed

+110
-73
lines changed

training/HelloDeepSpeed/run.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
python train_bert.py --checkpoint_dir ./experiment

training/HelloDeepSpeed/run_ds.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
deepspeed --bind_cores_to_rank train_bert_ds.py --checkpoint_dir experiment_deepspeed $@

training/HelloDeepSpeed/train_bert.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
RobertaPreTrainedModel,
2525
)
2626

27+
from deepspeed.accelerator import get_accelerator
28+
2729
logger = loguru.logger
2830

2931
######################################################################
@@ -625,8 +627,8 @@ def train(
625627
pathlib.Path: The final experiment directory
626628
627629
"""
628-
device = (torch.device("cuda", local_rank) if (local_rank > -1)
629-
and torch.cuda.is_available() else torch.device("cpu"))
630+
device = (torch.device(get_accelerator().device_name(), local_rank) if (local_rank > -1)
631+
and get_accelerator().is_available() else torch.device("cpu"))
630632
################################
631633
###### Create Exp. Dir #########
632634
################################

training/HelloDeepSpeed/train_bert_ds.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
RobertaPreTrainedModel,
3232
)
3333

34+
from deepspeed.accelerator import get_accelerator
3435

3536
def is_rank_0() -> bool:
3637
return int(os.environ.get("RANK", "0")) == 0
@@ -612,6 +613,7 @@ def train(
612613
checkpoint_every: int = 1000,
613614
log_every: int = 10,
614615
local_rank: int = -1,
616+
dtype: str = "bf16",
615617
) -> pathlib.Path:
616618
"""Trains a [Bert style](https://arxiv.org/pdf/1810.04805.pdf)
617619
(transformer encoder only) model for MLM Task
@@ -667,8 +669,8 @@ def train(
667669
pathlib.Path: The final experiment directory
668670
669671
"""
670-
device = (torch.device("cuda", local_rank) if (local_rank > -1)
671-
and torch.cuda.is_available() else torch.device("cpu"))
672+
device = (torch.device(get_accelerator().device_name(), local_rank) if (local_rank > -1)
673+
and get_accelerator().is_available() else torch.device("cpu"))
672674
################################
673675
###### Create Exp. Dir #########
674676
################################
@@ -777,6 +779,7 @@ def train(
777779
###### DeepSpeed engine ########
778780
################################
779781
log_dist("Creating DeepSpeed engine", ranks=[0], level=logging.INFO)
782+
assert (dtype == 'fp16' or dtype == 'bf16')
780783
ds_config = {
781784
"train_micro_batch_size_per_gpu": batch_size,
782785
"optimizer": {
@@ -785,7 +788,7 @@ def train(
785788
"lr": 1e-4
786789
}
787790
},
788-
"fp16": {
791+
dtype: {
789792
"enabled": True
790793
},
791794
"zero_optimization": {

training/cifar/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ run_ds_moe.sh
1818
* To run baseline CIFAR-10 model - "python cifar10_tutorial.py"
1919
* To run DeepSpeed CIFAR-10 model - "bash run_ds.sh"
2020
* To run DeepSpeed CIFAR-10 model with Mixture of Experts (MoE) - "bash run_ds_moe.sh"
21+
* To run with different data type (default='fp16') and zero stages (default=0) - "bash run_ds.sh --dtype={fp16|bf16} --stage={0|1|2|3}"

training/cifar/cifar10_deepspeed.py

Lines changed: 92 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torchvision.transforms as transforms
44
import argparse
55
import deepspeed
6+
from deepspeed.accelerator import get_accelerator
67

78

89
def add_argument():
@@ -88,6 +89,22 @@ def add_argument():
8889
help=
8990
'(moe) create separate moe param groups, required when using ZeRO w. MoE'
9091
)
92+
parser.add_argument(
93+
'--dtype',
94+
default='fp16',
95+
type=str,
96+
choices=['bf16', 'fp16', 'fp32'],
97+
help=
98+
'Datatype used for training'
99+
)
100+
parser.add_argument(
101+
'--stage',
102+
default=0,
103+
type=int,
104+
choices=[0, 1, 2, 3],
105+
help=
106+
'Datatype used for training'
107+
)
91108

92109
# Include DeepSpeed configuration arguments
93110
parser = deepspeed.add_config_arguments(parser)
@@ -243,11 +260,68 @@ def create_moe_param_groups(model):
243260
# 1) Distributed model
244261
# 2) Distributed data loader
245262
# 3) DeepSpeed optimizer
263+
ds_config = {
264+
"train_batch_size": 16,
265+
"steps_per_print": 2000,
266+
"optimizer": {
267+
"type": "Adam",
268+
"params": {
269+
"lr": 0.001,
270+
"betas": [
271+
0.8,
272+
0.999
273+
],
274+
"eps": 1e-8,
275+
"weight_decay": 3e-7
276+
}
277+
},
278+
"scheduler": {
279+
"type": "WarmupLR",
280+
"params": {
281+
"warmup_min_lr": 0,
282+
"warmup_max_lr": 0.001,
283+
"warmup_num_steps": 1000
284+
}
285+
},
286+
"gradient_clipping": 1.0,
287+
"prescale_gradients": False,
288+
"bf16": {
289+
"enabled": args.dtype == "bf16"
290+
},
291+
"fp16": {
292+
"enabled": args.dtype == "fp16",
293+
"fp16_master_weights_and_grads": False,
294+
"loss_scale": 0,
295+
"loss_scale_window": 500,
296+
"hysteresis": 2,
297+
"min_loss_scale": 1,
298+
"initial_scale_power": 15
299+
},
300+
"wall_clock_breakdown": False,
301+
"zero_optimization": {
302+
"stage": args.stage,
303+
"allgather_partitions": True,
304+
"reduce_scatter": True,
305+
"allgather_bucket_size": 50000000,
306+
"reduce_bucket_size": 50000000,
307+
"overlap_comm": True,
308+
"contiguous_gradients": True,
309+
"cpu_offload": False
310+
}
311+
}
312+
246313
model_engine, optimizer, trainloader, __ = deepspeed.initialize(
247-
args=args, model=net, model_parameters=parameters, training_data=trainset)
314+
args=args, model=net, model_parameters=parameters, training_data=trainset, config=ds_config)
315+
316+
local_device = get_accelerator().device_name(model_engine.local_rank)
317+
local_rank = model_engine.local_rank
248318

249-
fp16 = model_engine.fp16_enabled()
250-
print(f'fp16={fp16}')
319+
# For float32, target_dtype will be None so no datatype conversion needed
320+
target_dtype = None
321+
if model_engine.bfloat16_enabled():
322+
target_dtype=torch.bfloat16
323+
elif model_engine.fp16_enabled():
324+
target_dtype=torch.half
251325

252326
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
253327
#net.to(device)
@@ -274,10 +348,9 @@ def create_moe_param_groups(model):
274348
running_loss = 0.0
275349
for i, data in enumerate(trainloader):
276350
# get the inputs; data is a list of [inputs, labels]
277-
inputs, labels = data[0].to(model_engine.local_rank), data[1].to(
278-
model_engine.local_rank)
279-
if fp16:
280-
inputs = inputs.half()
351+
inputs, labels = data[0].to(local_device), data[1].to(local_device)
352+
if target_dtype != None:
353+
inputs = inputs.to(target_dtype)
281354
outputs = model_engine(inputs)
282355
loss = criterion(outputs, labels)
283356

@@ -286,7 +359,7 @@ def create_moe_param_groups(model):
286359

287360
# print statistics
288361
running_loss += loss.item()
289-
if i % args.log_interval == (
362+
if local_rank == 0 and i % args.log_interval == (
290363
args.log_interval -
291364
1): # print every log_interval mini-batches
292365
print('[%d, %5d] loss: %.3f' %
@@ -317,9 +390,9 @@ def create_moe_param_groups(model):
317390

318391
########################################################################
319392
# Okay, now let us see what the neural network thinks these examples above are:
320-
if fp16:
321-
images = images.half()
322-
outputs = net(images.to(model_engine.local_rank))
393+
if target_dtype != None:
394+
images = images.to(target_dtype)
395+
outputs = net(images.to(local_device))
323396

324397
########################################################################
325398
# The outputs are energies for the 10 classes.
@@ -340,13 +413,12 @@ def create_moe_param_groups(model):
340413
with torch.no_grad():
341414
for data in testloader:
342415
images, labels = data
343-
if fp16:
344-
images = images.half()
345-
outputs = net(images.to(model_engine.local_rank))
416+
if target_dtype != None:
417+
images = images.to(target_dtype)
418+
outputs = net(images.to(local_device))
346419
_, predicted = torch.max(outputs.data, 1)
347420
total += labels.size(0)
348-
correct += (predicted == labels.to(
349-
model_engine.local_rank)).sum().item()
421+
correct += (predicted == labels.to(local_device)).sum().item()
350422

351423
print('Accuracy of the network on the 10000 test images: %d %%' %
352424
(100 * correct / total))
@@ -364,11 +436,11 @@ def create_moe_param_groups(model):
364436
with torch.no_grad():
365437
for data in testloader:
366438
images, labels = data
367-
if fp16:
368-
images = images.half()
369-
outputs = net(images.to(model_engine.local_rank))
439+
if target_dtype != None:
440+
images = images.to(target_dtype)
441+
outputs = net(images.to(local_device))
370442
_, predicted = torch.max(outputs, 1)
371-
c = (predicted == labels.to(model_engine.local_rank)).squeeze()
443+
c = (predicted == labels.to(local_device)).squeeze()
372444
for i in range(4):
373445
label = labels[i]
374446
class_correct[label] += c[i].item()

training/cifar/ds_config.json

Lines changed: 0 additions & 46 deletions
This file was deleted.

training/cifar/run_ds.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/bin/bash
22

3-
deepspeed cifar10_deepspeed.py --deepspeed --deepspeed_config ds_config.json $@
3+
deepspeed --bind_cores_to_rank cifar10_deepspeed.py --deepspeed $@

training/cifar/run_ds_moe.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ EP_SIZE=2
99
# Number of total experts
1010
EXPERTS=2
1111

12-
deepspeed --num_nodes=${NUM_NODES} --num_gpus=${NUM_GPUS} cifar10_deepspeed.py \
12+
deepspeed --num_nodes=${NUM_NODES}\
13+
--num_gpus=${NUM_GPUS} \
14+
--bind_cores_to_rank \
15+
cifar10_deepspeed.py \
1316
--log-interval 100 \
1417
--deepspeed \
1518
--deepspeed_config ds_config.json \

0 commit comments

Comments
 (0)