Commit faccec1

Fix checkpoint API
Author: Kai Fricke (committed)
1 parent f075938 commit faccec1

File tree: 2 files changed (+35, -32 lines changed)


beginner_source/hyperparameter_tuning_tutorial.py

Lines changed: 34 additions & 26 deletions
@@ -14,7 +14,7 @@
 supports distributed training through `Ray's distributed machine learning engine
 <https://ray.io/>`_.

-In this tutorial, we will show you how to integrate Tune into your PyTorch
+In this tutorial, we will show you how to integrate Ray Tune into your PyTorch
 training workflow. We will extend `this tutorial from the PyTorch documentation
 <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_ for training
 a CIFAR10 image classifier.
@@ -62,6 +62,7 @@
 # We wrap the data loaders in their own function and pass a global data directory.
 # This way we can share a data directory between different trials.

+
 def load_data(data_dir="./data"):
     transform = transforms.Compose([
         transforms.ToTensor(),
@@ -82,6 +83,7 @@ def load_data(data_dir="./data"):
 # We can only tune those parameters that are configurable. In this example, we can specify
 # the layer sizes of the fully connected layers:

+
 class Net(nn.Module):
     def __init__(self, l1=120, l2=84):
         super(Net, self).__init__()
@@ -107,17 +109,20 @@ def forward(self, x):
 # Now it gets interesting, because we introduce some changes to the example `from the PyTorch
 # documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.
 #
-# We wrap the training script in a function ``train_cifar(config, checkpoint=None, data_dir=None)``.
+# We wrap the training script in a function ``train_cifar(config, checkpoint_dir=None, data_dir=None)``.
 # As you can guess, the ``config`` parameter will receive the hyperparameters we would like to
-# train with. The ``checkpoint`` parameter is used to restore checkpoints. The ``data_dir`` specifies
+# train with. The ``checkpoint_dir`` parameter is used to restore checkpoints. The ``data_dir`` specifies
 # the directory where we load and store the data, so multiple runs can share the same data source.
 #
 # .. code-block:: python
 #
 #     net = Net(config["l1"], config["l2"])
 #
-#     if checkpoint:
-#         net.load_state_dict(torch.load(checkpoint))
+#     if checkpoint_dir:
+#         model_state, optimizer_state = torch.load(
+#             os.path.join(checkpoint_dir, "checkpoint"))
+#         net.load_state_dict(model_state)
+#         optimizer.load_state_dict(optimizer_state)
 #
 # The learning rate of the optimizer is made configurable, too:
 #
@@ -162,25 +167,25 @@ def forward(self, x):
 # Communicating with Ray Tune
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
-# The most interesting part is the communication with Tune:
+# The most interesting part is the communication with Ray Tune:
 #
 # .. code-block:: python
 #
-#     checkpoint_dir = tune.make_checkpoint_dir(epoch)
-#     path = os.path.join(checkpoint_dir, "checkpoint")
-#     torch.save((net.state_dict(), optimizer.state_dict()), path)
-#     tune.save_checkpoint(path)
+#     with tune.checkpoint_dir(epoch) as checkpoint_dir:
+#         path = os.path.join(checkpoint_dir, "checkpoint")
+#         torch.save((net.state_dict(), optimizer.state_dict()), path)
 #
 #     tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
 #
-# Here we first save a checkpoint and then report some metrics back to Tune. Specifically,
-# we send the validation loss and accuracy back to Tune. Tune can then use these metrics
+# Here we first save a checkpoint and then report some metrics back to Ray Tune. Specifically,
+# we send the validation loss and accuracy back to Ray Tune. Ray Tune can then use these metrics
 # to decide which hyperparameter configuration lead to the best results. These metrics
 # can also be used to stop bad performing trials early in order to avoid wasting
 # resources on those trials.
 #
 # The checkpoint saving is optional, however, it is necessary if we wanted to use advanced
-# schedulers like `Population Based Training <https://docs.ray.io/en/master/tune/tutorials/tune-advanced-tutorial.html>`_.
+# schedulers like
+# `Population Based Training <https://docs.ray.io/en/master/tune/tutorials/tune-advanced-tutorial.html>`_.
 # Also, by saving the checkpoint we can later load the trained models and validate them
 # on a test set.
 #
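
As a side-by-side summary of the API change in the hunk above: the removed ``tune.make_checkpoint_dir``/``tune.save_checkpoint`` pair is replaced by the ``tune.checkpoint_dir`` context manager. The following stripped-down trainable is a minimal sketch (not part of the diff) of the full restore/save/report cycle with the new API; the trainable name, the dummy model, and the stand-in loss value are illustrative only.

# Minimal sketch of the new checkpoint API (illustrative, not from the diff).
import os

import torch
from ray import tune


def dummy_trainable(config, checkpoint_dir=None):
    model = torch.nn.Linear(1, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])

    # Restore: Ray Tune now hands the trainable a checkpoint *directory*.
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    for epoch in range(10):
        loss = 1.0 / (epoch + 1)  # stand-in for a real validation loss

        # Save: the context manager yields a managed directory to write into,
        # replacing tune.make_checkpoint_dir() + tune.save_checkpoint().
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((model.state_dict(), optimizer.state_dict()), path)

        # Report metrics so schedulers and search algorithms can act on them.
        tune.report(loss=loss)
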
@@ -189,7 +194,8 @@ def forward(self, x):
 #
 # The full code example looks like this:

-def train_cifar(config, checkpoint=None, data_dir=None):
+
+def train_cifar(config, checkpoint_dir=None, data_dir=None):
     net = Net(config["l1"], config["l2"])

     device = "cpu"
@@ -202,9 +208,9 @@ def train_cifar(config, checkpoint=None, data_dir=None):
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

-    if checkpoint:
-        print("loading checkpoint {}".format(checkpoint))
-        model_state, optimizer_state = torch.load(checkpoint)
+    if checkpoint_dir:
+        model_state, optimizer_state = torch.load(
+            os.path.join(checkpoint_dir, "checkpoint"))
         net.load_state_dict(model_state)
         optimizer.load_state_dict(optimizer_state)

@@ -269,10 +275,9 @@ def train_cifar(config, checkpoint=None, data_dir=None):
                val_loss += loss.cpu().numpy()
                val_steps += 1

-        checkpoint_dir = tune.make_checkpoint_dir(epoch)
-        path = os.path.join(checkpoint_dir, "checkpoint")
-        torch.save((net.state_dict(), optimizer.state_dict()), path)
-        tune.save_checkpoint(path)
+        with tune.checkpoint_dir(epoch) as checkpoint_dir:
+            path = os.path.join(checkpoint_dir, "checkpoint")
+            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")
@@ -286,6 +291,7 @@ def train_cifar(config, checkpoint=None, data_dir=None):
 # set with data that has not been used for training the model. We also wrap this in a
 # function:

+
 def test_accuracy(net, device="cpu"):
     trainset, testset = load_data()

@@ -311,7 +317,7 @@ def test_accuracy(net, device="cpu"):
 #
 # Configuring the search space
 # ----------------------------
-# Lastly, we need to define Tune's search space. Here is an example:
+# Lastly, we need to define Ray Tune's search space. Here is an example:
 #
 # .. code-block:: python
 #
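
The ``config`` dictionary referenced here sits outside the changed lines, so it does not appear in this hunk. Judging from the description in the next hunk (tunable layer sizes, a learning rate sampled between 0.0001 and 0.1, and a batch size chosen from 2, 4, 8, and 16), it presumably looks roughly like the sketch below; the exact ``l1``/``l2`` sampling and the use of log-uniform sampling for ``lr`` are assumptions, not taken from the diff.

# Sketch of a search space matching the surrounding prose (values assumed).
import numpy as np
from ray import tune

config = {
    "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
    "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16]),
}
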
@@ -328,7 +334,7 @@ def test_accuracy(net, device="cpu"):
 # The ``lr`` (learning rate) should be uniformly sampled between 0.0001 and 0.1. Lastly,
 # the batch size is a choice between 2, 4, 8, and 16.
 #
-# At each trial, Tune will now randomly sample a combination of parameters from these
+# At each trial, Ray Tune will now randomly sample a combination of parameters from these
 # search spaces. It will then train a number of models in parallel and find the best
 # performing one among these. We also use the ``ASHAScheduler`` which will terminate bad
 # performing trials early.
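
The ``ASHAScheduler`` setup is likewise unchanged context that the hunk does not show. A typical construction that matches the metric names reported via ``tune.report`` would look roughly like this sketch; the ``grace_period`` and ``reduction_factor`` values, and the use of the ``max_num_epochs`` argument of ``main()`` below, are assumptions.

# Sketch of a scheduler that early-stops trials based on the reported "loss".
from ray.tune.schedulers import ASHAScheduler

scheduler = ASHAScheduler(
    metric="loss",         # matches the metric name passed to tune.report()
    mode="min",            # lower validation loss is better
    max_t=max_num_epochs,  # assumed: the max_num_epochs argument of main()
    grace_period=1,        # let every trial run at least one epoch
    reduction_factor=2)
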
@@ -366,6 +372,7 @@ def test_accuracy(net, device="cpu"):
 #
 # The full main function looks like this:

+
 def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
     data_dir = os.path.abspath("./data")
     load_data(data_dir)
@@ -390,8 +397,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
         config=config,
         num_samples=num_samples,
         scheduler=scheduler,
-        progress_reporter=reporter,
-        checkpoint_at_end=True)
+        progress_reporter=reporter)

     best_trial = result.get_best_trial("loss", "min", "last")
     print("Best trial config: {}".format(best_trial.config))
@@ -408,7 +414,9 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

-    model_state, optimizer_state = torch.load(best_trial.checkpoint.value)
+    best_checkpoint_dir = best_trial.checkpoint.value
+    model_state, optimizer_state = torch.load(os.path.join(
+        best_checkpoint_dir, "checkpoint"))
    best_trained_model.load_state_dict(model_state)

    test_acc = test_accuracy(best_trained_model, device)

requirements.txt

Lines changed: 1 addition & 6 deletions
@@ -14,12 +14,7 @@ bs4
 awscli==1.16.35
 flask
 spacy
-
-# Replace this block with ray[tune] after next PyPI release
-https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl
-tabulate
-tensorboardX
-pandas
+ray[tune]

 # PyTorch Theme
 -e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
