@@ -105,7 +105,7 @@ def group_examples(self):
105
105
"""
106
106
107
107
# get the targets from MNIST dataset
108
- np_arr = np .array (self .dataset .targets .clone ())
108
+ np_arr = np .array (self .dataset .targets .clone (), dtype = None , copy = None )
109
109
110
110
# group examples based on class
111
111
self .grouped_examples = {}
@@ -247,10 +247,8 @@ def main():
247
247
help = 'learning rate (default: 1.0)' )
248
248
parser .add_argument ('--gamma' , type = float , default = 0.7 , metavar = 'M' ,
249
249
help = 'Learning rate step gamma (default: 0.7)' )
250
- parser .add_argument ('--no-cuda' , action = 'store_true' , default = False ,
251
- help = 'disables CUDA training' )
252
- parser .add_argument ('--no-mps' , action = 'store_true' , default = False ,
253
- help = 'disables macOS GPU training' )
250
+ parser .add_argument ('--accel' , action = 'store_true' ,
251
+ help = 'use accelerator' )
254
252
parser .add_argument ('--dry-run' , action = 'store_true' , default = False ,
255
253
help = 'quickly check a single pass' )
256
254
parser .add_argument ('--seed' , type = int , default = 1 , metavar = 'S' ,
@@ -260,22 +258,25 @@ def main():
260
258
parser .add_argument ('--save-model' , action = 'store_true' , default = False ,
261
259
help = 'For Saving the current Model' )
262
260
args = parser .parse_args ()
263
-
264
- use_cuda = not args .no_cuda and torch .cuda .is_available ()
265
- use_mps = not args .no_mps and torch .backends .mps .is_available ()
266
261
267
262
torch .manual_seed (args .seed )
268
263
269
- if use_cuda :
270
- device = torch .device ("cuda" )
271
- elif use_mps :
272
- device = torch .device ("mps" )
264
+ if args .accel and not torch .accelerator .is_available ():
265
+ print ("ERROR: accelerator is not available, try running on CPU" )
266
+ sys .exit (1 )
267
+ if not args .accel and torch .accelerator .is_available ():
268
+ print ("WARNING: accelerator is available, run with --accel to enable it" )
269
+
270
+ if args .accel :
271
+ device = torch .accelerator .current_accelerator ()
273
272
else :
274
273
device = torch .device ("cpu" )
274
+
275
+ print (f"Using device: { device } " )
275
276
276
277
train_kwargs = {'batch_size' : args .batch_size }
277
278
test_kwargs = {'batch_size' : args .test_batch_size }
278
- if use_cuda :
279
+ if device == "cuda" :
279
280
cuda_kwargs = {'num_workers' : 1 ,
280
281
'pin_memory' : True ,
281
282
'shuffle' : True }
0 commit comments