|
20 | 20 |
|
21 | 21 | You may download and run this tutorial as a standalone Python script.
|
22 | 22 | The only requirements are PyTorch 1.6+ and a CUDA-capable GPU.
|
| 23 | +
|
| 24 | +Mixed precision primarily benefits Tensor Core-enabled architectures (Volta, Turing, Ampere).
| 25 | +This recipe should show significant (2-3X) speedup on those architectures.
| 26 | +On earlier architectures (Kepler, Maxwell, Pascal), you may observe a modest speedup.
| 27 | +Run ``nvidia-smi`` to display your GPU's architecture.
23 | 28 | """
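An optional aside (not part of this diff): the compute capability can also be queried from Python instead of ``nvidia-smi``. The sketch below is illustrative only and assumes the Tensor Core cutoff of compute capability 7.0 (Volta).

# Optional sketch: report whether the current GPU has Tensor Cores.
# torch.cuda.get_device_capability() returns the (major, minor) compute capability;
# Tensor Cores are available on compute capability 7.0 (Volta) and newer.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"{torch.cuda.get_device_name()}: compute capability {major}.{minor}")
    if major >= 7:
        print("Tensor Cores present: expect a significant mixed-precision speedup.")
    else:
        print("No Tensor Cores: expect a more modest speedup.")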
|
24 | 29 |
|
25 | 30 | import torch, time, gc
|
@@ -212,6 +217,38 @@ def make_model(in_size, out_size, num_layers):
|
212 | 217 | scaler.update()
|
213 | 218 | opt.zero_grad()
|
214 | 219 |
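For context, the ``scaler.update()`` and ``opt.zero_grad()`` lines above are the tail of the recipe's mixed-precision training step. A minimal sketch of that pattern (assuming ``net``, ``opt``, ``loss_fn``, and a ``data`` iterable defined earlier in the tutorial):

# Sketch of a GradScaler/autocast training step; names are assumed from
# earlier in the tutorial, not defined in this excerpt.
scaler = torch.cuda.amp.GradScaler()
for input, target in data:
    with torch.cuda.amp.autocast():
        output = net(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()  # backward pass on the scaled loss
    scaler.step(opt)               # unscales grads; skips the step if infs/NaNs are found
    scaler.update()                # adjusts the scale factor for the next iteration
    opt.zero_grad()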
|
| 220 | +##########################################################
| 221 | +# Saving/Resuming
| 222 | +# ----------------
| 223 | +# To save/resume Amp-enabled runs with bitwise accuracy, use
| 224 | +# `scaler.state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.state_dict>`_ and
| 225 | +# `scaler.load_state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.load_state_dict>`_.
| 226 | +#
| 227 | +# When saving, save the scaler state dict alongside the usual model and optimizer state dicts.
| 228 | +# Do this either at the beginning of an iteration before any forward passes, or at the end of
| 229 | +# an iteration after ``scaler.update()``.
| 230 | +
| 231 | +checkpoint = {"model": net.state_dict(),
| 232 | +              "optimizer": opt.state_dict(),
| 233 | +              "scaler": scaler.state_dict()}
| 234 | +
| 235 | +# (write checkpoint as desired, e.g., ``torch.save(checkpoint, "filename")``.)
| 236 | +#
| 237 | +# When resuming, load the scaler state dict alongside the model and optimizer state dicts.
| 238 | +# (read checkpoint as desired, e.g.,
| 239 | +# ``checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device()))``)
| 240 | +
| 241 | +net.load_state_dict(checkpoint["model"])
| 242 | +opt.load_state_dict(checkpoint["optimizer"])
| 243 | +scaler.load_state_dict(checkpoint["scaler"])
| 244 | +
| 245 | +# If a checkpoint was created from a run _without_ mixed precision, and you want to resume training _with_ mixed precision,
| 246 | +# load model and optimizer states from the checkpoint as usual. The checkpoint won't contain a saved scaler state, so
| 247 | +# use a fresh instance of ``GradScaler``.
| 248 | +#
| 249 | +# If a checkpoint was created from a run _with_ mixed precision and you want to resume training _without_ mixed precision,
| 250 | +# load model and optimizer states from the checkpoint as usual, and ignore the saved scaler state.
| 251 | +
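A sketch tying the two resume cases together (illustrative only; ``use_amp`` and ``"checkpoint.pt"`` are assumed names, not from the tutorial):

# Sketch: resume whether or not the checkpoint contains a scaler state.
use_amp = True                     # set False to resume without mixed precision
checkpoint = torch.load("checkpoint.pt")
net.load_state_dict(checkpoint["model"])
opt.load_state_dict(checkpoint["optimizer"])
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
if use_amp and "scaler" in checkpoint:
    # The checkpoint came from a mixed-precision run: restore the scale state.
    scaler.load_state_dict(checkpoint["scaler"])
# Otherwise a fresh scaler is used (a no-op when enabled=False),
# and any saved scaler state is simply ignored.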
215 | 252 | ##########################################################
|
216 | 253 | # Inference/Evaluation
|
217 | 254 | # --------------------
|
|