                     help='seed for initializing training. ')
 parser.add_argument('--gpu', default=None, type=int,
                     help='GPU id to use.')
+parser.add_argument('--no-accel', action='store_true',
+                    help='disables accelerator')
 parser.add_argument('--multiprocessing-distributed', action='store_true',
                     help='Use multi-processing distributed training to launch '
                          'N processes per node, which has N GPUs. This is the '
@@ -104,8 +106,17 @@ def main():

     args.distributed = args.world_size > 1 or args.multiprocessing_distributed

-    if torch.cuda.is_available():
-        ngpus_per_node = torch.cuda.device_count()
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    if use_accel:
+        device = torch.accelerator.current_accelerator()
+    else:
+        device = torch.device("cpu")
+
+    print(f"Using device: {device}")
+
+    if device.type == 'cuda':
+        ngpus_per_node = torch.accelerator.device_count()
         if ngpus_per_node == 1 and args.dist_backend == "nccl":
             warnings.warn("nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")
     else:
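A minimal standalone sketch of the device-selection pattern the hunk above introduces, assuming PyTorch 2.6+ where the torch.accelerator API is available; pick_device and the no_accel flag are illustrative names mirroring the new --no-accel argument.

import torch


def pick_device(no_accel: bool = False) -> torch.device:
    # Prefer the current accelerator (cuda, mps, xpu, ...) unless it is
    # disabled or unavailable; otherwise fall back to CPU.
    if not no_accel and torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")


if __name__ == "__main__":
    print(f"Using device: {pick_device()}")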
@@ -127,8 +138,15 @@ def main_worker(gpu, ngpus_per_node, args):
     global best_acc1
     args.gpu = gpu

-    if args.gpu is not None:
-        print("Use GPU: {} for training".format(args.gpu))
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    if use_accel:
+        if args.gpu is not None:
+            torch.accelerator.set_device_index(args.gpu)
+            print("Use GPU: {} for training".format(args.gpu))
+        device = torch.accelerator.current_accelerator()
+    else:
+        device = torch.device("cpu")

     if args.distributed:
         if args.dist_url == "env://" and args.rank == -1:
@@ -147,16 +165,16 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     model = models.__dict__[args.arch]()

-    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
+    if not use_accel:
         print('using CPU, this will be slow')
     elif args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
         # should always set the single device scope, otherwise,
         # DistributedDataParallel will use all available devices.
-        if torch.cuda.is_available():
+        if device.type == 'cuda':
             if args.gpu is not None:
                 torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
+                model.cuda(device)
                 # When using a single GPU per process and per
                 # DistributedDataParallel, we need to divide the batch size
                 # ourselves based on the total number of GPUs of the current node.
@@ -168,29 +186,17 @@ def main_worker(gpu, ngpus_per_node, args):
             # DistributedDataParallel will divide and allocate batch_size to all
             # available GPUs if device_ids are not set
             model = torch.nn.parallel.DistributedDataParallel(model)
-    elif args.gpu is not None and torch.cuda.is_available():
-        torch.cuda.set_device(args.gpu)
-        model = model.cuda(args.gpu)
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-        model = model.to(device)
-    else:
+    elif device.type == 'cuda':
         # DataParallel will divide and allocate batch_size to all available GPUs
         if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
             model.features = torch.nn.DataParallel(model.features)
             model.cuda()
         else:
             model = torch.nn.DataParallel(model).cuda()
-
-    if torch.cuda.is_available():
-        if args.gpu:
-            device = torch.device('cuda:{}'.format(args.gpu))
-        else:
-            device = torch.device("cuda")
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
     else:
-        device = torch.device("cpu")
+        model.to(device)
+
+
     # define loss function (criterion), optimizer, and learning rate scheduler
     criterion = nn.CrossEntropyLoss().to(device)

@@ -207,9 +213,9 @@ def main_worker(gpu, ngpus_per_node, args):
             print("=> loading checkpoint '{}'".format(args.resume))
             if args.gpu is None:
                 checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
+            else:
                 # Map model to be loaded to specified single gpu.
-                loc = 'cuda:{}'.format(args.gpu)
+                loc = f'{device.type}:{args.gpu}'
                 checkpoint = torch.load(args.resume, map_location=loc)
             args.start_epoch = checkpoint['epoch']
             best_acc1 = checkpoint['best_acc1']
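An illustrative sketch (hypothetical helper and arguments) of the map_location remapping used when resuming above: storages saved on one device are loaded directly onto the device selected for this run.

import torch


def load_checkpoint(path, device, gpu_index=None):
    # A checkpoint written on e.g. 'cuda:0' can be mapped to 'cuda:<gpu_index>',
    # 'mps:0', or the CPU without first materializing it on the original device.
    loc = f"{device.type}:{gpu_index}" if gpu_index is not None else device
    return torch.load(path, map_location=loc)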
@@ -302,11 +308,14 @@ def main_worker(gpu, ngpus_per_node, args):


 def train(train_loader, model, criterion, optimizer, epoch, device, args):
-    batch_time = AverageMeter('Time', ':6.3f')
-    data_time = AverageMeter('Data', ':6.3f')
-    losses = AverageMeter('Loss', ':.4e')
-    top1 = AverageMeter('Acc@1', ':6.2f')
-    top5 = AverageMeter('Acc@5', ':6.2f')
+
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
+    data_time = AverageMeter('Data', use_accel, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.NONE)
+    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.NONE)
     progress = ProgressMeter(
         len(train_loader),
         [batch_time, data_time, losses, top1, top5],
@@ -349,18 +358,27 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):

 def validate(val_loader, model, criterion, args):

+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
     def run_validate(loader, base_progress=0):
+
+        if use_accel:
+            device = torch.accelerator.current_accelerator()
+        else:
+            device = torch.device("cpu")
+
         with torch.no_grad():
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
-                if torch.cuda.is_available():
-                    target = target.cuda(args.gpu, non_blocking=True)
+                if use_accel:
+                    if args.gpu is not None and device.type == 'cuda':
+                        torch.accelerator.set_device_index(args.gpu)
+                        images = images.cuda(args.gpu, non_blocking=True)
+                        target = target.cuda(args.gpu, non_blocking=True)
+                    else:
+                        images = images.to(device)
+                        target = target.to(device)

                 # compute output
                 output = model(images)
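A small sketch of the host-to-device transfer pattern above; Tensor.to(device, non_blocking=True) covers any backend, while the .cuda(index, non_blocking=True) branch is CUDA-specific. The helper name is illustrative.

import torch


def to_device(images, target, device):
    # non_blocking only overlaps the copy with compute when the source tensors
    # live in pinned (page-locked) host memory, as the DataLoader can provide.
    images = images.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    return images, target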
@@ -379,10 +397,10 @@ def run_validate(loader, base_progress=0):
                 if i % args.print_freq == 0:
                     progress.display(i + 1)

-    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
-    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
-    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
-    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
+    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.AVERAGE)
+    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.AVERAGE)
     progress = ProgressMeter(
         len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
         [batch_time, losses, top1, top5],
@@ -422,8 +440,9 @@ class Summary(Enum):

 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+    def __init__(self, name, use_accel, fmt=':f', summary_type=Summary.AVERAGE):
         self.name = name
+        self.use_accel = use_accel
         self.fmt = fmt
         self.summary_type = summary_type
         self.reset()
@@ -440,11 +459,9 @@ def update(self, val, n=1):
         self.count += n
         self.avg = self.sum / self.count

-    def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
+    def all_reduce(self):
+        if self.use_accel:
+            device = torch.accelerator.current_accelerator()
         else:
             device = torch.device("cpu")
         total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
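A sketch of the reduction this method performs, assuming torch.distributed is already initialized: sum and count are packed into one tensor so a single collective synchronizes both across ranks.

import torch
import torch.distributed as dist


def reduce_meter(total_sum, count, device):
    # Pack both statistics into one tensor so a single all_reduce suffices.
    total = torch.tensor([total_sum, count], dtype=torch.float32, device=device)
    dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
    return total[0].item(), int(total[1].item())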