supports distributed training through `Ray's distributed machine learning engine
<https://ray.io/>`_.

- In this tutorial, we will show you how to integrate Tune into your PyTorch
+ In this tutorial, we will show you how to integrate Ray Tune into your PyTorch
training workflow. We will extend `this tutorial from the PyTorch documentation
<https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_ for training
a CIFAR10 image classifier.
# We wrap the data loaders in their own function and pass a global data directory.
# This way we can share a data directory between different trials.

+
def load_data(data_dir="./data"):
    transform = transforms.Compose([
        transforms.ToTensor(),
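For context, ``load_data`` is reused further down both by ``train_cifar`` and by ``test_accuracy``; a minimal usage sketch (the shared directory path is illustrative):

.. code-block:: python

    import os

    # prepare CIFAR10 once in a shared location so all trials can reuse it
    data_dir = os.path.abspath("./data")
    trainset, testset = load_data(data_dir)
    print(len(trainset), len(testset))  # CIFAR10 ships 50,000 train / 10,000 test images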
@@ -82,6 +83,7 @@ def load_data(data_dir="./data"):
# We can only tune those parameters that are configurable. In this example, we can specify
# the layer sizes of the fully connected layers:

+
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
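The hunk stops after the ``__init__`` signature; for reference, the two tunable sizes feed the fully connected layers of the otherwise unchanged CIFAR10 network roughly like this (a sketch, not part of the diff):

.. code-block:: python

    class Net(nn.Module):
        def __init__(self, l1=120, l2=84):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)        # conv/pool layers as in the original tutorial
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, l1)   # l1 and l2 are the tunable layer sizes
            self.fc2 = nn.Linear(l1, l2)
            self.fc3 = nn.Linear(l2, 10)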
@@ -107,17 +109,20 @@ def forward(self, x):
# Now it gets interesting, because we introduce some changes to the example `from the PyTorch
# documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.
#
- # We wrap the training script in a function ``train_cifar(config, checkpoint=None, data_dir=None)``.
+ # We wrap the training script in a function ``train_cifar(config, checkpoint_dir=None, data_dir=None)``.
# As you can guess, the ``config`` parameter will receive the hyperparameters we would like to
- # train with. The ``checkpoint`` parameter is used to restore checkpoints. The ``data_dir`` specifies
+ # train with. The ``checkpoint_dir`` parameter is used to restore checkpoints. The ``data_dir`` specifies
# the directory where we load and store the data, so multiple runs can share the same data source.
#
# .. code-block:: python
#
#     net = Net(config["l1"], config["l2"])
#
- #     if checkpoint:
- #         net.load_state_dict(torch.load(checkpoint))
+ #     if checkpoint_dir:
+ #         model_state, optimizer_state = torch.load(
+ #             os.path.join(checkpoint_dir, "checkpoint"))
+ #         net.load_state_dict(model_state)
+ #         optimizer.load_state_dict(optimizer_state)
#
# The learning rate of the optimizer is made configurable, too:
#
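The code-block that the last sentence introduces falls outside this hunk; as it appears later in ``train_cifar``, the learning rate simply comes from the ``config`` dict:

.. code-block:: python

    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)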
@@ -162,25 +167,25 @@ def forward(self, x):
# Communicating with Ray Tune
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
- # The most interesting part is the communication with Tune:
+ # The most interesting part is the communication with Ray Tune:
#
# .. code-block:: python
#
- #     checkpoint_dir = tune.make_checkpoint_dir(epoch)
- #     path = os.path.join(checkpoint_dir, "checkpoint")
- #     torch.save((net.state_dict(), optimizer.state_dict()), path)
- #     tune.save_checkpoint(path)
+ #     with tune.checkpoint_dir(epoch) as checkpoint_dir:
+ #         path = os.path.join(checkpoint_dir, "checkpoint")
+ #         torch.save((net.state_dict(), optimizer.state_dict()), path)
#
#     tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
#
- # Here we first save a checkpoint and then report some metrics back to Tune. Specifically,
- # we send the validation loss and accuracy back to Tune. Tune can then use these metrics
+ # Here we first save a checkpoint and then report some metrics back to Ray Tune. Specifically,
+ # we send the validation loss and accuracy back to Ray Tune. Ray Tune can then use these metrics
# to decide which hyperparameter configuration lead to the best results. These metrics
# can also be used to stop bad performing trials early in order to avoid wasting
# resources on those trials.
#
# The checkpoint saving is optional, however, it is necessary if we wanted to use advanced
- # schedulers like `Population Based Training <https://docs.ray.io/en/master/tune/tutorials/tune-advanced-tutorial.html>`_.
+ # schedulers like
+ # `Population Based Training <https://docs.ray.io/en/master/tune/tutorials/tune-advanced-tutorial.html>`_.
# Also, by saving the checkpoint we can later load the trained models and validate them
# on a test set.
#
@@ -189,7 +194,8 @@ def forward(self, x):
#
# The full code example looks like this:

- def train_cifar(config, checkpoint=None, data_dir=None):
+
+ def train_cifar(config, checkpoint_dir=None, data_dir=None):
    net = Net(config["l1"], config["l2"])

    device = "cpu"
@@ -202,9 +208,9 @@ def train_cifar(config, checkpoint=None, data_dir=None):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

-     if checkpoint:
-         print("loading checkpoint {}".format(checkpoint))
-         model_state, optimizer_state = torch.load(checkpoint)
+     if checkpoint_dir:
+         model_state, optimizer_state = torch.load(
+             os.path.join(checkpoint_dir, "checkpoint"))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
@@ -269,10 +275,9 @@ def train_cifar(config, checkpoint=None, data_dir=None):
                val_loss += loss.cpu().numpy()
                val_steps += 1

-         checkpoint_dir = tune.make_checkpoint_dir(epoch)
-         path = os.path.join(checkpoint_dir, "checkpoint")
-         torch.save((net.state_dict(), optimizer.state_dict()), path)
-         tune.save_checkpoint(path)
+         with tune.checkpoint_dir(epoch) as checkpoint_dir:
+             path = os.path.join(checkpoint_dir, "checkpoint")
+             torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")
@@ -286,6 +291,7 @@ def train_cifar(config, checkpoint=None, data_dir=None):
# set with data that has not been used for training the model. We also wrap this in a
# function:

+
def test_accuracy(net, device="cpu"):
    trainset, testset = load_data()
@@ -311,7 +317,7 @@ def test_accuracy(net, device="cpu"):
#
# Configuring the search space
# ----------------------------
- # Lastly, we need to define Tune's search space. Here is an example:
+ # Lastly, we need to define Ray Tune's search space. Here is an example:
#
# .. code-block:: python
#
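The search-space example itself is not included in this hunk. Based on the description in the next hunk, it looks roughly like this (the sampling helpers for ``l1`` and ``l2`` are an assumption; ``lr`` and ``batch_size`` follow the ranges described there):

.. code-block:: python

    import numpy as np
    from ray import tune

    config = {
        # layer sizes: powers of two are assumed here for illustration
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        # learning rate sampled between 0.0001 and 0.1 (log-uniform)
        "lr": tune.loguniform(1e-4, 1e-1),
        # batch size is a choice between 2, 4, 8, and 16
        "batch_size": tune.choice([2, 4, 8, 16])
    }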
@@ -328,7 +334,7 @@ def test_accuracy(net, device="cpu"):
# The ``lr`` (learning rate) should be uniformly sampled between 0.0001 and 0.1. Lastly,
# the batch size is a choice between 2, 4, 8, and 16.
#
- # At each trial, Tune will now randomly sample a combination of parameters from these
+ # At each trial, Ray Tune will now randomly sample a combination of parameters from these
# search spaces. It will then train a number of models in parallel and find the best
# performing one among these. We also use the ``ASHAScheduler`` which will terminate bad
# performing trials early.
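The ``ASHAScheduler`` construction is also outside the shown hunks; a sketch of a typical setup for this tutorial, minimizing the ``loss`` reported via ``tune.report`` (the concrete argument values are assumptions):

.. code-block:: python

    from ray.tune.schedulers import ASHAScheduler

    scheduler = ASHAScheduler(
        metric="loss",           # matches tune.report(loss=...)
        mode="min",              # lower validation loss is better
        max_t=max_num_epochs,    # upper bound on training epochs per trial
        grace_period=1,          # let every trial run at least one epoch
        reduction_factor=2)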
@@ -366,6 +372,7 @@ def test_accuracy(net, device="cpu"):
#
# The full main function looks like this:

+
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    data_dir = os.path.abspath("./data")
    load_data(data_dir)
@@ -390,8 +397,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
-         progress_reporter=reporter,
-         checkpoint_at_end=True)
+         progress_reporter=reporter)

    best_trial = result.get_best_trial("loss", "min", "last")
    print("Best trial config: {}".format(best_trial.config))
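The hunk above sits inside the ``tune.run(...)`` call; the surrounding call looks roughly like this after the change, with ``partial`` binding the shared data directory and ``resources_per_trial`` controlling per-trial CPU/GPU allocation (the resource numbers are assumptions):

.. code-block:: python

    from functools import partial

    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter)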
@@ -408,7 +414,9 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

-     model_state, optimizer_state = torch.load(best_trial.checkpoint.value)
+     best_checkpoint_dir = best_trial.checkpoint.value
+     model_state, optimizer_state = torch.load(os.path.join(
+         best_checkpoint_dir, "checkpoint"))
    best_trained_model.load_state_dict(model_state)

    test_acc = test_accuracy(best_trained_model, device)