Commit ac602d6

arch notes and saving/restoring
1 parent c503f29 commit ac602d6

File tree

1 file changed: 37 additions, 0 deletions

recipes_source/recipes/amp_recipe.py
@@ -20,6 +20,11 @@
 
 You may download and run this tutorial as a standalone Python script.
 The only requirements are PyTorch 1.6+ and a CUDA-capable GPU.
+
+Mixed precision primarily benefits Tensor Core-enabled architectures (Volta, Turing, Ampere).
+This recipe should show significant (2-3X) speedup on those architectures.
+On earlier architectures (Kepler, Maxwell, Pascal), you may observe a modest speedup.
+Run ``nvidia-smi`` to display your GPU's architecture.
 """
 
 import torch, time, gc
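The architecture note added in this hunk can also be checked from Python rather than by reading ``nvidia-smi`` output. A minimal sketch, assuming PyTorch is installed and assuming (not stated in the diff) that Tensor Cores correspond to compute capability 7.0 and above:

```python
import torch

# Assumption, not from the recipe: Tensor Cores require compute
# capability 7.0+ (Volta, Turing, Ampere and newer).
TENSOR_CORE_MIN_CAPABILITY = (7, 0)

def supports_tensor_cores(capability):
    """Return True if a (major, minor) compute capability has Tensor Cores."""
    return capability >= TENSOR_CORE_MIN_CAPABILITY

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability()
    print(f"Compute capability {cap}, Tensor Cores: {supports_tensor_cores(cap)}")
else:
    print("No CUDA device found; run nvidia-smi to inspect your GPU.")
```

On a pre-Volta GPU such as Pascal (capability 6.x) this reports ``False``, matching the diff's note that only a modest speedup is expected there.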
@@ -212,6 +217,38 @@ def make_model(in_size, out_size, num_layers):
         scaler.update()
         opt.zero_grad()
 
+##########################################################
+# Saving/Resuming
+# ----------------
+# To save/resume Amp-enabled runs with bitwise accuracy, use
+# `scaler.state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.state_dict>`_ and
+# `scaler.load_state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.load_state_dict>`_.
+#
+# When saving, save the scaler state dict alongside the usual model and optimizer state dicts.
+# Do this either at the beginning of an iteration before any forward passes, or at the end of
+# an iteration after ``scaler.update()``.
+
+checkpoint = {"model": net.state_dict(),
+              "optimizer": opt.state_dict(),
+              "scaler": scaler.state_dict()}
+
+# (Write the checkpoint as desired, e.g., ``torch.save(checkpoint, "filename")``.)
+#
+# When resuming, load the scaler state dict alongside the model and optimizer state dicts.
+# (Read the checkpoint as desired, e.g.,
+# ``checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device()))``.)
+
+net.load_state_dict(checkpoint["model"])
+opt.load_state_dict(checkpoint["optimizer"])
+scaler.load_state_dict(checkpoint["scaler"])
+
+# If a checkpoint was created from a run *without* mixed precision, and you want to resume training *with* mixed precision,
+# load model and optimizer states from the checkpoint as usual. The checkpoint won't contain a saved scaler state, so
+# use a fresh instance of ``GradScaler``.
+#
+# If a checkpoint was created from a run *with* mixed precision and you want to resume training *without* mixed precision,
+# load model and optimizer states from the checkpoint as usual, and ignore the saved scaler state.
+
 ##########################################################
 # Inference/Evaluation
 # --------------------
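The save/resume pattern in this hunk can be exercised end to end. A minimal sketch, assuming a tiny ``nn.Linear`` model in place of the recipe's ``net``/``opt`` and an in-memory buffer in place of a real checkpoint file (both hypothetical, not from the recipe):

```python
import io
import torch

# Hypothetical stand-ins for the recipe's net/opt/scaler.
device = "cuda" if torch.cuda.is_available() else "cpu"
net = torch.nn.Linear(8, 8).to(device)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

# One scaled training step so the scaler has state worth saving.
with torch.cuda.amp.autocast(enabled=(device == "cuda")):
    loss = net(torch.randn(4, 8, device=device)).sum()
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
opt.zero_grad()

# Save model, optimizer, and scaler state dicts together, as the diff advises.
checkpoint = {"model": net.state_dict(),
              "optimizer": opt.state_dict(),
              "scaler": scaler.state_dict()}
buffer = io.BytesIO()  # in-memory stand-in for torch.save(checkpoint, "filename")
torch.save(checkpoint, buffer)

# Resume: load all three state dicts into fresh instances.
buffer.seek(0)
loaded = torch.load(buffer)
net2 = torch.nn.Linear(8, 8).to(device)
opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)
scaler2 = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
net2.load_state_dict(loaded["model"])
opt2.load_state_dict(loaded["optimizer"])
scaler2.load_state_dict(loaded["scaler"])
```

After loading, ``scaler2.state_dict()`` matches the saved scaler state, so resuming continues with the same loss scale rather than re-warming it from the default.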
