Let's start with the imports:
"""
from functools import partial
-import numpy as np
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from ray import tune
-from ray.tune import CLIReporter
+from ray.air import Checkpoint, session
from ray.tune.schedulers import ASHAScheduler

######################################################################


def load_data(data_dir="./data"):
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-    ])
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )

    trainset = torchvision.datasets.CIFAR10(
-        root=data_dir, train=True, download=True, transform=transform)
+        root=data_dir, train=True, download=True, transform=transform
+    )

    testset = torchvision.datasets.CIFAR10(
-        root=data_dir, train=False, download=True, transform=transform)
+        root=data_dir, train=False, download=True, transform=transform
+    )

    return trainset, testset

+
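
######################################################################
# As a quick sanity check (and a minimal usage sketch, assuming CIFAR-10 can be
# downloaded to ``./data``), ``load_data`` can be exercised on its own:
#
# .. code-block:: python
#
#     trainset, testset = load_data()
#     print(len(trainset), len(testset))  # expected: 50000 10000
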
######################################################################
# Configurable neural network
# ---------------------------
-# We can only tune those parameters that are configurable. In this example, we can specify
+# We can only tune those parameters that are configurable.
+# In this example, we can specify
# the layer sizes of the fully connected layers:

@@ -97,32 +99,40 @@ def __init__(self, l1=120, l2=84):
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
-        x = x.view(-1, 16 * 5 * 5)
+        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

+
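
######################################################################
# As an illustrative aside (not part of the tuning pipeline), the network can be
# smoke-tested with a dummy CIFAR-10 sized batch; with the default layer sizes
# the output should contain one logit per CIFAR-10 class:
#
# .. code-block:: python
#
#     net = Net()
#     dummy = torch.randn(4, 3, 32, 32)  # batch of four fake 32x32 RGB images
#     print(net(dummy).shape)            # expected: torch.Size([4, 10])
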
######################################################################
# The train function
# ------------------
# Now it gets interesting, because we introduce some changes to the example `from the PyTorch
# documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.
#
-# We wrap the training script in a function ``train_cifar(config, checkpoint_dir=None, data_dir=None)``.
-# As you can guess, the ``config`` parameter will receive the hyperparameters we would like to
-# train with. The ``checkpoint_dir`` parameter is used to restore checkpoints. The ``data_dir`` specifies
-# the directory where we load and store the data, so multiple runs can share the same data source.
+# We wrap the training script in a function ``train_cifar(config, data_dir=None)``.
+# The ``config`` parameter will receive the hyperparameters we would like to
+# train with. The ``data_dir`` specifies the directory where we load and store the data,
+# so that multiple runs can share the same data source.
+# We also load the model and optimizer state at the start of the run, if a checkpoint
+# is provided. Further down in this tutorial you will find information on how
+# to save the checkpoint and what it is used for.
#
# .. code-block:: python
#
#     net = Net(config["l1"], config["l2"])
#
-#     if checkpoint_dir:
-#         model_state, optimizer_state = torch.load(
-#             os.path.join(checkpoint_dir, "checkpoint"))
-#         net.load_state_dict(model_state)
-#         optimizer.load_state_dict(optimizer_state)
+#     checkpoint = session.get_checkpoint()
+#
+#     if checkpoint:
+#         checkpoint_state = checkpoint.to_dict()
+#         start_epoch = checkpoint_state["epoch"]
+#         net.load_state_dict(checkpoint_state["net_state_dict"])
+#         optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
+#     else:
+#         start_epoch = 0
#
# The learning rate of the optimizer is made configurable, too:
#
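# The relevant line, repeated here from the full training function shown below:
#
# .. code-block:: python
#
#     optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
#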
@@ -171,11 +181,17 @@ def forward(self, x):
#
# .. code-block:: python
#
-#     with tune.checkpoint_dir(epoch) as checkpoint_dir:
-#         path = os.path.join(checkpoint_dir, "checkpoint")
-#         torch.save((net.state_dict(), optimizer.state_dict()), path)
+#     checkpoint_data = {
+#         "epoch": epoch,
+#         "net_state_dict": net.state_dict(),
+#         "optimizer_state_dict": optimizer.state_dict(),
+#     }
+#     checkpoint = Checkpoint.from_dict(checkpoint_data)
#
-#     tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
+#     session.report(
+#         {"loss": val_loss / val_steps, "accuracy": correct / total},
+#         checkpoint=checkpoint,
+#     )
#
# Here we first save a checkpoint and then report some metrics back to Ray Tune. Specifically,
# we send the validation loss and accuracy back to Ray Tune. Ray Tune can then use these metrics
@@ -187,15 +203,16 @@ def forward(self, x):
# schedulers like
# `Population Based Training <https://docs.ray.io/en/master/tune/tutorials/tune-advanced-tutorial.html>`_.
# Also, by saving the checkpoint we can later load the trained models and validate them
-# on a test set.
+# on a test set. Lastly, saving checkpoints is useful for fault tolerance, and it allows
+# us to interrupt training and continue training later.
#
# Full training function
# ~~~~~~~~~~~~~~~~~~~~~~
#
# The full code example looks like this:


-def train_cifar(config, checkpoint_dir=None, data_dir=None):
+def train_cifar(config, data_dir=None):
    net = Net(config["l1"], config["l2"])

    device = "cpu"
@@ -208,30 +225,31 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

-    if checkpoint_dir:
-        model_state, optimizer_state = torch.load(
-            os.path.join(checkpoint_dir, "checkpoint"))
-        net.load_state_dict(model_state)
-        optimizer.load_state_dict(optimizer_state)
+    checkpoint = session.get_checkpoint()
+
+    if checkpoint:
+        checkpoint_state = checkpoint.to_dict()
+        start_epoch = checkpoint_state["epoch"]
+        net.load_state_dict(checkpoint_state["net_state_dict"])
+        optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
+    else:
+        start_epoch = 0

    trainset, testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
-        trainset, [test_abs, len(trainset) - test_abs])
+        trainset, [test_abs, len(trainset) - test_abs]
+    )

    trainloader = torch.utils.data.DataLoader(
-        train_subset,
-        batch_size=int(config["batch_size"]),
-        shuffle=True,
-        num_workers=8)
+        train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
+    )
    valloader = torch.utils.data.DataLoader(
-        val_subset,
-        batch_size=int(config["batch_size"]),
-        shuffle=True,
-        num_workers=8)
+        val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
+    )

-    for epoch in range(10):  # loop over the dataset multiple times
+    for epoch in range(start_epoch, 10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
@@ -252,8 +270,10 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
-                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
-                                                running_loss / epoch_steps))
+                print(
+                    "[%d, %5d] loss: %.3f"
+                    % (epoch + 1, i + 1, running_loss / epoch_steps)
+                )
                running_loss = 0.0

        # Validation loss
@@ -275,13 +295,20 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
                val_loss += loss.cpu().numpy()
                val_steps += 1

-        with tune.checkpoint_dir(epoch) as checkpoint_dir:
-            path = os.path.join(checkpoint_dir, "checkpoint")
-            torch.save((net.state_dict(), optimizer.state_dict()), path)
-
-        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
+        checkpoint_data = {
+            "epoch": epoch,
+            "net_state_dict": net.state_dict(),
+            "optimizer_state_dict": optimizer.state_dict(),
+        }
+        checkpoint = Checkpoint.from_dict(checkpoint_data)
+
+        session.report(
+            {"loss": val_loss / val_steps, "accuracy": correct / total},
+            checkpoint=checkpoint,
+        )
    print("Finished Training")

+
######################################################################
# As you can see, most of the code is adapted directly from the original example.
#
@@ -296,7 +323,8 @@ def test_accuracy(net, device="cpu"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
-        testset, batch_size=4, shuffle=False, num_workers=2)
+        testset, batch_size=4, shuffle=False, num_workers=2
+    )

    correct = 0
    total = 0
@@ -311,6 +339,7 @@ def test_accuracy(net, device="cpu"):

    return correct / total

+
######################################################################
# The function also expects a ``device`` parameter, so we can do the
# test set validation on a GPU.
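#
# A minimal usage sketch (assuming a trained ``net`` and, optionally, an
# available CUDA device):
#
# .. code-block:: python
#
#     device = "cuda:0" if torch.cuda.is_available() else "cpu"
#     net.to(device)
#     print(test_accuracy(net, device))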
@@ -322,14 +351,14 @@ def test_accuracy(net, device="cpu"):
# .. code-block:: python
#
#     config = {
-#         "l1": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
-#         "l2": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
+#         "l1": tune.choice([2 ** i for i in range(9)]),
+#         "l2": tune.choice([2 ** i for i in range(9)]),
#         "lr": tune.loguniform(1e-4, 1e-1),
#         "batch_size": tune.choice([2, 4, 8, 16])
#     }
#
-# The ``tune.sample_from()`` function makes it possible to define your own sample
-# methods to obtain hyperparameters. In this example, the ``l1`` and ``l2`` parameters
+# The ``tune.choice()`` accepts a list of values that are uniformly sampled from.
+# In this example, the ``l1`` and ``l2`` parameters
# should be powers of 2 between 1 and 256, i.e. 1, 2, 4, 8, 16, 32, 64, 128, or 256.
# The ``lr`` (learning rate) should be uniformly sampled between 0.0001 and 0.1. Lastly,
# the batch size is a choice between 2, 4, 8, and 16.
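#
# To make the sampling concrete, here is a small plain-Python sketch (using
# ``random.choice`` instead of Ray Tune's own samplers, purely for illustration)
# of what drawing ``l1`` from this search space amounts to:
#
# .. code-block:: python
#
#     import random
#
#     candidates = [2 ** i for i in range(9)]  # 1, 2, 4, ..., 256
#     print(random.choice(candidates))         # e.g. 64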
@@ -353,7 +382,6 @@ def test_accuracy(net, device="cpu"):
#         config=config,
#         num_samples=num_samples,
#         scheduler=scheduler,
-#         progress_reporter=reporter,
#         checkpoint_at_end=True)
#
# You can specify the number of CPUs, which are then available e.g.
@@ -377,34 +405,30 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    data_dir = os.path.abspath("./data")
    load_data(data_dir)
    config = {
-        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
-        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
+        "l1": tune.choice([2 ** i for i in range(9)]),
+        "l2": tune.choice([2 ** i for i in range(9)]),
        "lr": tune.loguniform(1e-4, 1e-1),
-        "batch_size": tune.choice([2, 4, 8, 16])
+        "batch_size": tune.choice([2, 4, 8, 16]),
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
-        reduction_factor=2)
-    reporter = CLIReporter(
-        # ``parameter_columns=["l1", "l2", "lr", "batch_size"]``,
-        metric_columns=["loss", "accuracy", "training_iteration"])
+        reduction_factor=2,
+    )
    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
-        progress_reporter=reporter)
+    )

    best_trial = result.get_best_trial("loss", "min", "last")
-    print("Best trial config: {}".format(best_trial.config))
-    print("Best trial final validation loss: {}".format(
-        best_trial.last_result["loss"]))
-    print("Best trial final validation accuracy: {}".format(
-        best_trial.last_result["accuracy"]))
+    print(f"Best trial config: {best_trial.config}")
+    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
+    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

    best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    device = "cpu"
@@ -414,10 +438,10 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

-    best_checkpoint_dir = best_trial.checkpoint.value
-    model_state, optimizer_state = torch.load(os.path.join(
-        best_checkpoint_dir, "checkpoint"))
-    best_trained_model.load_state_dict(model_state)
+    best_checkpoint = best_trial.checkpoint.to_air_checkpoint()
+    best_checkpoint_data = best_checkpoint.to_dict()
+
+    best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])

    test_acc = test_accuracy(best_trained_model, device)
    print("Best trial test set accuracy: {}".format(test_acc))
@@ -428,6 +452,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    # Fixes ``AttributeError: '_LoggingTee' object has no attribute 'fileno'``.
    # This is only needed to run with sphinx-build.
    import sys
+
    sys.stdout.fileno = lambda: False
    # sphinx_gallery_end_ignore
    # You can change the number of GPUs per trial here:
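    # An illustrative invocation (the argument values here are only an example,
    # mirroring the defaults of ``main`` above), running all trials on CPU:
    main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)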
@@ -439,30 +464,29 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
#
# ::
#
-#     Number of trials: 10 (10 TERMINATED)
-#     +-----+------+------+-------------+--------------+---------+------------+--------------------+
-#     | ... | l1 | l2 | lr | batch_size | loss | accuracy | training_iteration |
-#     |-----+------+------+-------------+--------------+---------+------------+--------------------|
-#     | ... | 64 | 4 | 0.00011629 | 2 | 1.87273 | 0.244 | 2 |
-#     | ... | 32 | 64 | 0.000339763 | 8 | 1.23603 | 0.567 | 8 |
-#     | ... | 8 | 16 | 0.00276249 | 16 | 1.1815 | 0.5836 | 10 |
-#     | ... | 4 | 64 | 0.000648721 | 4 | 1.31131 | 0.5224 | 8 |
-#     | ... | 32 | 16 | 0.000340753 | 8 | 1.26454 | 0.5444 | 8 |
-#     | ... | 8 | 4 | 0.000699775 | 8 | 1.99594 | 0.1983 | 2 |
-#     | ... | 256 | 8 | 0.0839654 | 16 | 2.3119 | 0.0993 | 1 |
-#     | ... | 16 | 128 | 0.0758154 | 16 | 2.33575 | 0.1327 | 1 |
-#     | ... | 16 | 8 | 0.0763312 | 16 | 2.31129 | 0.1042 | 4 |
-#     | ... | 128 | 16 | 0.000124903 | 4 | 2.26917 | 0.1945 | 1 |
-#     +-----+------+------+-------------+--------------+---------+------------+--------------------+
-#
-#
-#     Best trial config: {'l1': 8, 'l2': 16, 'lr': 0.00276249, 'batch_size': 16, 'data_dir': '...'}
-#     Best trial final validation loss: 1.181501
-#     Best trial final validation accuracy: 0.5836
-#     Best trial test set accuracy: 0.5806
+#     Number of trials: 10/10 (10 TERMINATED)
+#     +-----+--------------+------+------+-------------+--------+---------+------------+
+#     | ... | batch_size | l1 | l2 | lr | iter | loss | accuracy |
+#     |-----+--------------+------+------+-------------+--------+---------+------------|
+#     | ... | 2 | 1 | 256 | 0.000668163 | 1 | 2.31479 | 0.0977 |
+#     | ... | 4 | 64 | 8 | 0.0331514 | 1 | 2.31605 | 0.0983 |
+#     | ... | 4 | 2 | 1 | 0.000150295 | 1 | 2.30755 | 0.1023 |
+#     | ... | 16 | 32 | 32 | 0.0128248 | 10 | 1.66912 | 0.4391 |
+#     | ... | 4 | 8 | 128 | 0.00464561 | 2 | 1.7316 | 0.3463 |
+#     | ... | 8 | 256 | 8 | 0.00031556 | 1 | 2.19409 | 0.1736 |
+#     | ... | 4 | 16 | 256 | 0.00574329 | 2 | 1.85679 | 0.3368 |
+#     | ... | 8 | 2 | 2 | 0.00325652 | 1 | 2.30272 | 0.0984 |
+#     | ... | 2 | 2 | 2 | 0.000342987 | 2 | 1.76044 | 0.292 |
+#     | ... | 4 | 64 | 32 | 0.003734 | 8 | 1.53101 | 0.4761 |
+#     +-----+--------------+------+------+-------------+--------+---------+------------+
+#
+#     Best trial config: {'l1': 64, 'l2': 32, 'lr': 0.0037339984519545164, 'batch_size': 4}
+#     Best trial final validation loss: 1.5310075663924216
+#     Best trial final validation accuracy: 0.4761
+#     Best trial test set accuracy: 0.4737
#
# Most trials have been stopped early in order to avoid wasting resources.
-# The best performing trial achieved a validation accuracy of about 58 %, which could
+# The best performing trial achieved a validation accuracy of about 47 %, which could
# be confirmed on the test set.
#
# So that's it! You can now tune the parameters of your PyTorch models.