33import torchvision .transforms as transforms
44import argparse
55import deepspeed
6+ from deepspeed .accelerator import get_accelerator
67
78
89def add_argument ():
@@ -88,6 +89,22 @@ def add_argument():
8889 help =
8990 '(moe) create separate moe param groups, required when using ZeRO w. MoE'
9091 )
92+ parser .add_argument (
93+ '--dtype' ,
94+ default = 'fp16' ,
95+ type = str ,
96+ choices = ['bf16' , 'fp16' , 'fp32' ],
97+ help =
98+ 'Datatype used for training'
99+ )
100+ parser .add_argument (
101+ '--stage' ,
102+ default = 0 ,
103+ type = int ,
104+ choices = [0 , 1 , 2 , 3 ],
105+ help =
106+ 'Datatype used for training'
107+ )
91108
92109 # Include DeepSpeed configuration arguments
93110 parser = deepspeed .add_config_arguments (parser )
@@ -243,11 +260,68 @@ def create_moe_param_groups(model):
243260# 1) Distributed model
244261# 2) Distributed data loader
245262# 3) DeepSpeed optimizer
# DeepSpeed runtime configuration, assembled from the parsed CLI arguments.
# Mixed precision is selected via --dtype and ZeRO partitioning via --stage.

# Adam with a warmup LR schedule.
_optimizer_config = {
    "type": "Adam",
    "params": {
        "lr": 0.001,
        "betas": [0.8, 0.999],
        "eps": 1e-8,
        "weight_decay": 3e-7,
    },
}

_scheduler_config = {
    "type": "WarmupLR",
    "params": {
        "warmup_min_lr": 0,
        "warmup_max_lr": 0.001,
        "warmup_num_steps": 1000,
    },
}

# ZeRO stage comes straight from --stage; bucket sizes bound the
# allgather/reduce-scatter message granularity.
_zero_config = {
    "stage": args.stage,
    "allgather_partitions": True,
    "reduce_scatter": True,
    "allgather_bucket_size": 50000000,
    "reduce_bucket_size": 50000000,
    "overlap_comm": True,
    "contiguous_gradients": True,
    "cpu_offload": False,
}

ds_config = {
    "train_batch_size": 16,
    "steps_per_print": 2000,
    "optimizer": _optimizer_config,
    "scheduler": _scheduler_config,
    "gradient_clipping": 1.0,
    "prescale_gradients": False,
    # Exactly one of bf16/fp16 is enabled unless --dtype is fp32.
    "bf16": {
        "enabled": args.dtype == "bf16",
    },
    "fp16": {
        "enabled": args.dtype == "fp16",
        "fp16_master_weights_and_grads": False,
        "loss_scale": 0,  # 0 selects dynamic loss scaling
        "loss_scale_window": 500,
        "hysteresis": 2,
        "min_loss_scale": 1,
        "initial_scale_power": 15,
    },
    "wall_clock_breakdown": False,
    "zero_optimization": _zero_config,
}
312+
# Initialize DeepSpeed: wraps the model in an engine, partitions optimizer
# state per the configured ZeRO stage, and builds a distributed loader over
# the training set.
model_engine, optimizer, trainloader, __ = deepspeed.initialize(
    args=args,
    model=net,
    model_parameters=parameters,
    training_data=trainset,
    config=ds_config)

local_rank = model_engine.local_rank
local_device = get_accelerator().device_name(local_rank)

# Inputs must be cast to the engine's compute dtype before the forward pass;
# None means fp32 training, i.e. no conversion is needed.
if model_engine.bfloat16_enabled():
    target_dtype = torch.bfloat16
elif model_engine.fp16_enabled():
    target_dtype = torch.half
else:
    target_dtype = None
251325
252326#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
253327#net.to(device)
@@ -274,10 +348,9 @@ def create_moe_param_groups(model):
274348 running_loss = 0.0
275349 for i , data in enumerate (trainloader ):
276350 # get the inputs; data is a list of [inputs, labels]
277- inputs , labels = data [0 ].to (model_engine .local_rank ), data [1 ].to (
278- model_engine .local_rank )
279- if fp16 :
280- inputs = inputs .half ()
351+ inputs , labels = data [0 ].to (local_device ), data [1 ].to (local_device )
352+ if target_dtype != None :
353+ inputs = inputs .to (target_dtype )
281354 outputs = model_engine (inputs )
282355 loss = criterion (outputs , labels )
283356
@@ -286,7 +359,7 @@ def create_moe_param_groups(model):
286359
287360 # print statistics
288361 running_loss += loss .item ()
289- if i % args .log_interval == (
362+ if local_rank == 0 and i % args .log_interval == (
290363 args .log_interval -
291364 1 ): # print every log_interval mini-batches
292365 print ('[%d, %5d] loss: %.3f' %
@@ -317,9 +390,9 @@ def create_moe_param_groups(model):
317390
318391########################################################################
319392# Okay, now let us see what the neural network thinks these examples above are:
# Match the engine's compute dtype (no-op for fp32), then run the samples
# through the trained network on the local accelerator.
images = images if target_dtype is None else images.to(target_dtype)
outputs = net(images.to(local_device))
323396
324397########################################################################
325398# The outputs are energies for the 10 classes.
@@ -340,13 +413,12 @@ def create_moe_param_groups(model):
# Evaluate overall top-1 accuracy over the test set; gradients are not
# needed here, so run under no_grad to save memory.
with torch.no_grad():
    for batch in testloader:
        images, labels = batch
        # Cast inputs when the engine trains in bf16/fp16 (None => fp32).
        if target_dtype is not None:
            images = images.to(target_dtype)
        outputs = net(images.to(local_device))
        _, predicted = outputs.data.max(dim=1)
        total += labels.size(0)
        correct += predicted.eq(labels.to(local_device)).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' %
      (100 * correct / total))
@@ -364,11 +436,11 @@ def create_moe_param_groups(model):
364436with torch .no_grad ():
365437 for data in testloader :
366438 images , labels = data
367- if fp16 :
368- images = images .half ( )
369- outputs = net (images .to (model_engine . local_rank ))
439+ if target_dtype != None :
440+ images = images .to ( target_dtype )
441+ outputs = net (images .to (local_device ))
370442 _ , predicted = torch .max (outputs , 1 )
371- c = (predicted == labels .to (model_engine . local_rank )).squeeze ()
443+ c = (predicted == labels .to (local_device )).squeeze ()
372444 for i in range (4 ):
373445 label = labels [i ]
374446 class_correct [label ] += c [i ].item ()
0 commit comments