Commit dfbba6e

krfricke and Svetlana Karslioglu authored
Update hyperparameter optimization with Ray Tune tutorial (#2318)
* Update code
* Reformat
* Output
* remove total time s

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent c877c59 commit dfbba6e

2 files changed: +117 -93 lines changed


beginner_source/hyperparameter_tuning_tutorial.py

Lines changed: 116 additions & 92 deletions
@@ -40,7 +40,6 @@
 Let's start with the imports:
 """
 from functools import partial
-import numpy as np
 import os
 import torch
 import torch.nn as nn
@@ -50,7 +49,7 @@
 import torchvision
 import torchvision.transforms as transforms
 from ray import tune
-from ray.tune import CLIReporter
+from ray.air import Checkpoint, session
 from ray.tune.schedulers import ASHAScheduler

 ######################################################################
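The import change above swaps the CLIReporter / tune.report workflow for the ray.air session and Checkpoint API. As a quick orientation, here is a minimal sketch of the reporting loop this enables; the toy trainable, config key, and metric values are illustrative only and are not part of the tutorial:

    from ray import tune
    from ray.air import session

    def toy_trainable(config):
        for step in range(3):
            # session.report(...) replaces the older tune.report(loss=...) call
            session.report({"loss": config["x"] ** 2 / (step + 1)})

    tune.run(toy_trainable, config={"x": tune.uniform(-1.0, 1.0)}, num_samples=2)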
@@ -64,23 +63,26 @@


 def load_data(data_dir="./data"):
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-    ])
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )

     trainset = torchvision.datasets.CIFAR10(
-        root=data_dir, train=True, download=True, transform=transform)
+        root=data_dir, train=True, download=True, transform=transform
+    )

     testset = torchvision.datasets.CIFAR10(
-        root=data_dir, train=False, download=True, transform=transform)
+        root=data_dir, train=False, download=True, transform=transform
+    )

     return trainset, testset

+
 ######################################################################
 # Configurable neural network
 # ---------------------------
-# We can only tune those parameters that are configurable. In this example, we can specify
+# We can only tune those parameters that are configurable.
+# In this example, we can specify
 # the layer sizes of the fully connected layers:


@@ -97,32 +99,40 @@ def __init__(self, l1=120, l2=84):
     def forward(self, x):
         x = self.pool(F.relu(self.conv1(x)))
         x = self.pool(F.relu(self.conv2(x)))
-        x = x.view(-1, 16 * 5 * 5)
+        x = torch.flatten(x, 1)  # flatten all dimensions except batch
         x = F.relu(self.fc1(x))
         x = F.relu(self.fc2(x))
         x = self.fc3(x)
         return x

+
 ######################################################################
 # The train function
 # ------------------
 # Now it gets interesting, because we introduce some changes to the example `from the PyTorch
 # documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.
 #
-# We wrap the training script in a function ``train_cifar(config, checkpoint_dir=None, data_dir=None)``.
-# As you can guess, the ``config`` parameter will receive the hyperparameters we would like to
-# train with. The ``checkpoint_dir`` parameter is used to restore checkpoints. The ``data_dir`` specifies
-# the directory where we load and store the data, so multiple runs can share the same data source.
+# We wrap the training script in a function ``train_cifar(config, data_dir=None)``.
+# The ``config`` parameter will receive the hyperparameters we would like to
+# train with. The ``data_dir`` specifies the directory where we load and store the data,
+# so that multiple runs can share the same data source.
+# We also load the model and optimizer state at the start of the run, if a checkpoint
+# is provided. Further down in this tutorial you will find information on how
+# to save the checkpoint and what it is used for.
 #
 # .. code-block:: python
 #
 #     net = Net(config["l1"], config["l2"])
 #
-#     if checkpoint_dir:
-#         model_state, optimizer_state = torch.load(
-#             os.path.join(checkpoint_dir, "checkpoint"))
-#         net.load_state_dict(model_state)
-#         optimizer.load_state_dict(optimizer_state)
+#     checkpoint = session.get_checkpoint()
+#
+#     if checkpoint:
+#         checkpoint_state = checkpoint.to_dict()
+#         start_epoch = checkpoint_state["epoch"]
+#         net.load_state_dict(checkpoint_state["net_state_dict"])
+#         optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
+#     else:
+#         start_epoch = 0
 #
 # The learning rate of the optimizer is made configurable, too:
 #
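For readers tracking the API change: the old torch.load-based restore is replaced by a dict-backed AIR checkpoint. A standalone sketch of the save/restore round trip used here, outside the training loop; the tiny nn.Linear model is just a stand-in for the tutorial's Net:

    import torch.nn as nn
    from ray.air import Checkpoint

    net = nn.Linear(4, 2)  # stand-in for Net(config["l1"], config["l2"])
    checkpoint = Checkpoint.from_dict({"epoch": 3, "net_state_dict": net.state_dict()})

    state = checkpoint.to_dict()          # what session.get_checkpoint() hands back
    net.load_state_dict(state["net_state_dict"])
    start_epoch = state["epoch"]          # mirrors the restore logic in train_cifar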
@@ -171,11 +181,17 @@ def forward(self, x):
 #
 # .. code-block:: python
 #
-#     with tune.checkpoint_dir(epoch) as checkpoint_dir:
-#         path = os.path.join(checkpoint_dir, "checkpoint")
-#         torch.save((net.state_dict(), optimizer.state_dict()), path)
+#     checkpoint_data = {
+#         "epoch": epoch,
+#         "net_state_dict": net.state_dict(),
+#         "optimizer_state_dict": optimizer.state_dict(),
+#     }
+#     checkpoint = Checkpoint.from_dict(checkpoint_data)
 #
-#     tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
+#     session.report(
+#         {"loss": val_loss / val_steps, "accuracy": correct / total},
+#         checkpoint=checkpoint,
+#     )
 #
 # Here we first save a checkpoint and then report some metrics back to Ray Tune. Specifically,
 # we send the validation loss and accuracy back to Ray Tune. Ray Tune can then use these metrics
@@ -187,15 +203,16 @@ def forward(self, x):
 # schedulers like
 # `Population Based Training <https://docs.ray.io/en/master/tune/tutorials/tune-advanced-tutorial.html>`_.
 # Also, by saving the checkpoint we can later load the trained models and validate them
-# on a test set.
+# on a test set. Lastly, saving checkpoints is useful for fault tolerance, and it allows
+# us to interrupt training and continue training later.
 #
 # Full training function
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # The full code example looks like this:


-def train_cifar(config, checkpoint_dir=None, data_dir=None):
+def train_cifar(config, data_dir=None):
     net = Net(config["l1"], config["l2"])

     device = "cpu"
@@ -208,30 +225,31 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

-    if checkpoint_dir:
-        model_state, optimizer_state = torch.load(
-            os.path.join(checkpoint_dir, "checkpoint"))
-        net.load_state_dict(model_state)
-        optimizer.load_state_dict(optimizer_state)
+    checkpoint = session.get_checkpoint()
+
+    if checkpoint:
+        checkpoint_state = checkpoint.to_dict()
+        start_epoch = checkpoint_state["epoch"]
+        net.load_state_dict(checkpoint_state["net_state_dict"])
+        optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
+    else:
+        start_epoch = 0

     trainset, testset = load_data(data_dir)

     test_abs = int(len(trainset) * 0.8)
     train_subset, val_subset = random_split(
-        trainset, [test_abs, len(trainset) - test_abs])
+        trainset, [test_abs, len(trainset) - test_abs]
+    )

     trainloader = torch.utils.data.DataLoader(
-        train_subset,
-        batch_size=int(config["batch_size"]),
-        shuffle=True,
-        num_workers=8)
+        train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
+    )
     valloader = torch.utils.data.DataLoader(
-        val_subset,
-        batch_size=int(config["batch_size"]),
-        shuffle=True,
-        num_workers=8)
+        val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
+    )

-    for epoch in range(10):  # loop over the dataset multiple times
+    for epoch in range(start_epoch, 10):  # loop over the dataset multiple times
         running_loss = 0.0
         epoch_steps = 0
         for i, data in enumerate(trainloader, 0):
@@ -252,8 +270,10 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
             running_loss += loss.item()
             epoch_steps += 1
             if i % 2000 == 1999:  # print every 2000 mini-batches
-                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
-                                                running_loss / epoch_steps))
+                print(
+                    "[%d, %5d] loss: %.3f"
+                    % (epoch + 1, i + 1, running_loss / epoch_steps)
+                )
                 running_loss = 0.0

         # Validation loss
@@ -275,13 +295,20 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
                 val_loss += loss.cpu().numpy()
                 val_steps += 1

-        with tune.checkpoint_dir(epoch) as checkpoint_dir:
-            path = os.path.join(checkpoint_dir, "checkpoint")
-            torch.save((net.state_dict(), optimizer.state_dict()), path)
-
-        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
+        checkpoint_data = {
+            "epoch": epoch,
+            "net_state_dict": net.state_dict(),
+            "optimizer_state_dict": optimizer.state_dict(),
+        }
+        checkpoint = Checkpoint.from_dict(checkpoint_data)
+
+        session.report(
+            {"loss": val_loss / val_steps, "accuracy": correct / total},
+            checkpoint=checkpoint,
+        )
     print("Finished Training")

+
 ######################################################################
 # As you can see, most of the code is adapted directly from the original example.
 #
@@ -296,7 +323,8 @@ def test_accuracy(net, device="cpu"):
     trainset, testset = load_data()

     testloader = torch.utils.data.DataLoader(
-        testset, batch_size=4, shuffle=False, num_workers=2)
+        testset, batch_size=4, shuffle=False, num_workers=2
+    )

     correct = 0
     total = 0
@@ -311,6 +339,7 @@ def test_accuracy(net, device="cpu"):

     return correct / total

+
 ######################################################################
 # The function also expects a ``device`` parameter, so we can do the
 # test set validation on a GPU.
@@ -322,14 +351,14 @@ def test_accuracy(net, device="cpu"):
 # .. code-block:: python
 #
 #     config = {
-#         "l1": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
-#         "l2": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
+#         "l1": tune.choice([2 ** i for i in range(9)]),
+#         "l2": tune.choice([2 ** i for i in range(9)]),
 #         "lr": tune.loguniform(1e-4, 1e-1),
 #         "batch_size": tune.choice([2, 4, 8, 16])
 #     }
 #
-# The ``tune.sample_from()`` function makes it possible to define your own sample
-# methods to obtain hyperparameters. In this example, the ``l1`` and ``l2`` parameters
+# The ``tune.choice()`` accepts a list of values that are uniformly sampled from.
+# In this example, the ``l1`` and ``l2`` parameters
 # should be powers of 2 between 4 and 256, so either 4, 8, 16, 32, 64, 128, or 256.
 # The ``lr`` (learning rate) should be uniformly sampled between 0.0001 and 0.1. Lastly,
 # the batch size is a choice between 2, 4, 8, and 16.
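To make the sampling behaviour concrete, here is a plain-Python stand-in for one draw from this search space (illustrative only; the tutorial itself uses Ray Tune's own samplers):

    import math
    import random

    def sample_config():
        # plain-Python stand-in for one draw from the Ray Tune search space above
        return {
            "l1": random.choice([2 ** i for i in range(9)]),   # 1, 2, 4, ..., 256
            "l2": random.choice([2 ** i for i in range(9)]),
            "lr": math.exp(random.uniform(math.log(1e-4), math.log(1e-1))),  # log-uniform
            "batch_size": random.choice([2, 4, 8, 16]),
        }

    print(sample_config())  # e.g. {'l1': 64, 'l2': 8, 'lr': 0.0031, 'batch_size': 4}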
@@ -353,7 +382,6 @@ def test_accuracy(net, device="cpu"):
 #         config=config,
 #         num_samples=num_samples,
 #         scheduler=scheduler,
-#         progress_reporter=reporter,
 #         checkpoint_at_end=True)
 #
 # You can specify the number of CPUs, which are then available e.g.
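One practical note on resources_per_trial: the GPU value may be fractional, so several trials can share a device. A sketch with illustrative values, reusing the names defined in this file; it assumes at least one GPU is visible to Ray, otherwise trials wait for resources:

    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": 0.5},  # two trials share each GPU
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        checkpoint_at_end=True,
    )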
@@ -377,34 +405,30 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
     data_dir = os.path.abspath("./data")
     load_data(data_dir)
     config = {
-        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
-        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
+        "l1": tune.choice([2**i for i in range(9)]),
+        "l2": tune.choice([2**i for i in range(9)]),
         "lr": tune.loguniform(1e-4, 1e-1),
-        "batch_size": tune.choice([2, 4, 8, 16])
+        "batch_size": tune.choice([2, 4, 8, 16]),
     }
     scheduler = ASHAScheduler(
         metric="loss",
         mode="min",
         max_t=max_num_epochs,
         grace_period=1,
-        reduction_factor=2)
-    reporter = CLIReporter(
-        # ``parameter_columns=["l1", "l2", "lr", "batch_size"]``,
-        metric_columns=["loss", "accuracy", "training_iteration"])
+        reduction_factor=2,
+    )
     result = tune.run(
         partial(train_cifar, data_dir=data_dir),
         resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
         config=config,
         num_samples=num_samples,
         scheduler=scheduler,
-        progress_reporter=reporter)
+    )

     best_trial = result.get_best_trial("loss", "min", "last")
-    print("Best trial config: {}".format(best_trial.config))
-    print("Best trial final validation loss: {}".format(
-        best_trial.last_result["loss"]))
-    print("Best trial final validation accuracy: {}".format(
-        best_trial.last_result["accuracy"]))
+    print(f"Best trial config: {best_trial.config}")
+    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
+    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

     best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
     device = "cpu"
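A back-of-the-envelope sketch of what the ASHAScheduler settings above imply, using the generic ASHA rung formula (grace_period multiplied by successive powers of reduction_factor, capped at max_t); this is an illustration, not an inspection of Ray's internals:

    grace_period, reduction_factor, max_t = 1, 2, 10
    rungs, rung = [], grace_period
    while rung <= max_t:
        rungs.append(rung)       # epochs at which trials are compared and may be stopped
        rung *= reduction_factor
    print(rungs)  # [1, 2, 4, 8]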
@@ -414,10 +438,10 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
             best_trained_model = nn.DataParallel(best_trained_model)
     best_trained_model.to(device)

-    best_checkpoint_dir = best_trial.checkpoint.value
-    model_state, optimizer_state = torch.load(os.path.join(
-        best_checkpoint_dir, "checkpoint"))
-    best_trained_model.load_state_dict(model_state)
+    best_checkpoint = best_trial.checkpoint.to_air_checkpoint()
+    best_checkpoint_data = best_checkpoint.to_dict()
+
+    best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])

     test_acc = test_accuracy(best_trained_model, device)
     print("Best trial test set accuracy: {}".format(test_acc))
@@ -428,6 +452,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
     # Fixes ``AttributeError: '_LoggingTee' object has no attribute 'fileno'``.
     # This is only needed to run with sphinx-build.
     import sys
+
     sys.stdout.fileno = lambda: False
     # sphinx_gallery_end_ignore
     # You can change the number of GPUs per trial here:
@@ -439,30 +464,29 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
 #
 # ::
 #
-#     Number of trials: 10 (10 TERMINATED)
-#     +-----+------+------+-------------+--------------+---------+------------+--------------------+
-#     | ... |   l1 |   l2 |          lr |   batch_size |    loss |   accuracy | training_iteration |
-#     |-----+------+------+-------------+--------------+---------+------------+--------------------|
-#     | ... |   64 |    4 |  0.00011629 |            2 | 1.87273 |      0.244 |                  2 |
-#     | ... |   32 |   64 | 0.000339763 |            8 | 1.23603 |      0.567 |                  8 |
-#     | ... |    8 |   16 |  0.00276249 |           16 |  1.1815 |     0.5836 |                 10 |
-#     | ... |    4 |   64 | 0.000648721 |            4 | 1.31131 |     0.5224 |                  8 |
-#     | ... |   32 |   16 | 0.000340753 |            8 | 1.26454 |     0.5444 |                  8 |
-#     | ... |    8 |    4 | 0.000699775 |            8 | 1.99594 |     0.1983 |                  2 |
-#     | ... |  256 |    8 |   0.0839654 |           16 |  2.3119 |     0.0993 |                  1 |
-#     | ... |   16 |  128 |   0.0758154 |           16 | 2.33575 |     0.1327 |                  1 |
-#     | ... |   16 |    8 |   0.0763312 |           16 | 2.31129 |     0.1042 |                  4 |
-#     | ... |  128 |   16 | 0.000124903 |            4 | 2.26917 |     0.1945 |                  1 |
-#     +-----+------+------+-------------+--------------+---------+------------+--------------------+
-#
-#
-#     Best trial config: {'l1': 8, 'l2': 16, 'lr': 0.00276249, 'batch_size': 16, 'data_dir': '...'}
-#     Best trial final validation loss: 1.181501
-#     Best trial final validation accuracy: 0.5836
-#     Best trial test set accuracy: 0.5806
+#     Number of trials: 10/10 (10 TERMINATED)
+#     +-----+--------------+------+------+-------------+--------+---------+------------+
+#     | ... |   batch_size |   l1 |   l2 |          lr |   iter |    loss |   accuracy |
+#     |-----+--------------+------+------+-------------+--------+---------+------------|
+#     | ... |            2 |    1 |  256 | 0.000668163 |      1 | 2.31479 |     0.0977 |
+#     | ... |            4 |   64 |    8 |   0.0331514 |      1 | 2.31605 |     0.0983 |
+#     | ... |            4 |    2 |    1 | 0.000150295 |      1 | 2.30755 |     0.1023 |
+#     | ... |           16 |   32 |   32 |   0.0128248 |     10 | 1.66912 |     0.4391 |
+#     | ... |            4 |    8 |  128 |  0.00464561 |      2 |  1.7316 |     0.3463 |
+#     | ... |            8 |  256 |    8 |  0.00031556 |      1 | 2.19409 |     0.1736 |
+#     | ... |            4 |   16 |  256 |  0.00574329 |      2 | 1.85679 |     0.3368 |
+#     | ... |            8 |    2 |    2 |  0.00325652 |      1 | 2.30272 |     0.0984 |
+#     | ... |            2 |    2 |    2 | 0.000342987 |      2 | 1.76044 |      0.292 |
+#     | ... |            4 |   64 |   32 |    0.003734 |      8 | 1.53101 |     0.4761 |
+#     +-----+--------------+------+------+-------------+--------+---------+------------+
+#
+#     Best trial config: {'l1': 64, 'l2': 32, 'lr': 0.0037339984519545164, 'batch_size': 4}
+#     Best trial final validation loss: 1.5310075663924216
+#     Best trial final validation accuracy: 0.4761
+#     Best trial test set accuracy: 0.4737
 #
 # Most trials have been stopped early in order to avoid wasting resources.
-# The best performing trial achieved a validation accuracy of about 58%, which could
+# The best performing trial achieved a validation accuracy of about 47%, which could
 # be confirmed on the test set.
 #
 # So that's it! You can now tune the parameters of your PyTorch models.
