
Commit fe33b54

mcarilli and brianjo authored
Python recipe for automatic mixed precision (#1137)
* fdsa
* Tutorial runs
* clarify one scaler per convergence run
* adjust sizes, dont run illustrative sections
* satisfying ocd
* MORE
* fdsa
* details
* rephrase
* fix formatting
* move script to recipes
* hopefully moved to recipes
* fdsa
* add amp_tutorial to toctree
* amp_tutorial -> amp_recipe
* looks like backtick highlights dont render in card_description
* correct path for amp_recipe.html
* arch notes and saving/restoring
* formatting
* fdsa
* Clarify autograd-autocast interaction for custom ops
* touchups

Co-authored-by: Brian Johnson <brianjo@fb.com>
1 parent ee5e448 commit fe33b54

File tree

5 files changed (+363, -0 lines changed)

_static/img/thumbnails/cropped/amp.png

14.5 KB (binary image; no text diff)

advanced_source/dispatcher.rst

Lines changed: 24 additions & 0 deletions
@@ -105,6 +105,8 @@ speaking, the structure of your registrations will look like this:
 that provides implementations for all basic operators on the XLA dispatch
 key.
 
+.. _autograd-support:
+
 Adding autograd support
 -----------------------
 
@@ -299,6 +301,28 @@ the safest choice for the execution type:
                  at::autocast::cached_cast(exec_type, t1));
   }
 
+If your custom op is :ref:`autograd-enabled<autograd-support>`, you only need to write and register
+an autocast wrapper for the same name onto which the autograd wrapper is registered.
+For example, if you wanted an autocast wrapper for the ``myadd`` function shown
+in the autograd section, all you'd need is
+
+.. code-block:: cpp
+
+  Tensor myadd_autocast(const Tensor& self, const Tensor& other) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+    return myadd(at::autocast::cached_cast(<desired dtype>, self),
+                 at::autocast::cached_cast(<desired dtype>, other));
+  }
+
+  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+    m.impl("myadd", myadd_autocast);
+  }
+
+There are no separate gymnastics to make the backward method autocast compatible.
+However, the backward method defined in your custom autograd function will run in the same
+dtype as autocast sets for the forward method, so you should choose a ``<desired dtype>``
+suitable for both your forward and backward methods.
 
 Batched
 ^^^^^^^
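For illustration, a minimal Python-side sketch of exercising such a wrapper. This is an
assumption-laden example: it presumes the ``myops`` extension defining ``myops::myadd``
(with the Autocast registration above) has already been compiled and loaded, and that
``<desired dtype>`` was chosen as ``torch.float16``:

.. code-block:: python

    import torch

    a = torch.randn(8, 8, device="cuda")
    b = torch.randn(8, 8, device="cuda")

    with torch.cuda.amp.autocast():
        # Inputs are cached-cast to the chosen dtype before reaching the registered
        # kernel, so the result comes back in that dtype.
        out = torch.ops.myops.myadd(a, b)
        assert out.dtype is torch.float16

    # Outside an autocast-enabled region the op runs in whatever dtype its inputs have.
    out_fp32 = torch.ops.myops.myadd(a, b)
    assert out_fp32.dtype is torch.float32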
recipes_source/recipes/README.txt

Lines changed: 4 additions & 0 deletions
@@ -56,3 +56,7 @@ PyTorch Recipes
 14. mobile_perf.py
          PyTorch Mobile Performance Recipes
          https://pytorch.org/tutorials/recipes/mobile_perf.html
+
+15. amp_recipe.py
+         Automatic Mixed Precision
+         https://pytorch.org/tutorials/recipes/amp_recipe.html

recipes_source/recipes/amp_recipe.py

Lines changed: 325 additions & 0 deletions
@@ -0,0 +1,325 @@
# -*- coding: utf-8 -*-
"""
Automatic Mixed Precision
*************************
**Author**: `Michael Carilli <https://github.com/mcarilli>`_

`torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ provides convenience methods for mixed precision,
where some operations use the ``torch.float32`` (``float``) datatype and other operations
use ``torch.float16`` (``half``). Some ops, like linear layers and convolutions,
are much faster in ``float16``. Other ops, like reductions, often require the dynamic
range of ``float32``. Mixed precision tries to match each op to its appropriate datatype,
which can reduce your network's runtime and memory footprint.

Ordinarily, "automatic mixed precision training" uses `torch.cuda.amp.autocast <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_ and
`torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_ together.

This recipe measures the performance of a simple network in default precision,
then walks through adding ``autocast`` and ``GradScaler`` to run the same network in
mixed precision with improved performance.

You may download and run this recipe as a standalone Python script.
The only requirements are PyTorch 1.6+ and a CUDA-capable GPU.

Mixed precision primarily benefits Tensor Core-enabled architectures (Volta, Turing, Ampere).
This recipe should show significant (2-3X) speedup on those architectures.
On earlier architectures (Kepler, Maxwell, Pascal), you may observe a modest speedup.
Run ``nvidia-smi`` to display your GPU's architecture.
"""

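##########################################################
# A quick programmatic check (an illustrative sketch): devices with compute
# capability 7.0 or higher (Volta, Turing, Ampere) have Tensor Cores.

import torch
major, minor = torch.cuda.get_device_capability()
print("Tensor Cores available:", major >= 7)
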
import torch, time, gc

# Timing utilities
start_time = None

def start_timer():
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.synchronize()
    start_time = time.time()

def end_timer_and_print(local_msg):
    torch.cuda.synchronize()
    end_time = time.time()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

##########################################################
# A simple network
# ----------------
# The following sequence of linear layers and ReLUs should show a speedup with mixed precision.

def make_model(in_size, out_size, num_layers):
    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(in_size, out_size))
    return torch.nn.Sequential(*tuple(layers)).cuda()

##########################################################
# ``batch_size``, ``in_size``, ``out_size``, and ``num_layers`` are chosen to be large enough to saturate the GPU with work.
# Typically, mixed precision provides the greatest speedup when the GPU is saturated.
# Small networks may be CPU bound, in which case mixed precision won't improve performance.
# Sizes are also chosen such that linear layers' participating dimensions are multiples of 8,
# to permit Tensor Core usage on Tensor Core-capable GPUs (see :ref:`Troubleshooting<troubleshooting>` below).
#
# Exercise: Vary participating sizes and see how the mixed precision speedup changes.

batch_size = 512 # Try, for example, 128, 256, 513.
in_size = 4096
out_size = 4096
num_layers = 3
num_batches = 50
epochs = 3

# Creates data in default precision.
# The same data is used for both default and mixed precision trials below.
# You don't need to manually change inputs' dtype when enabling mixed precision.
data = [torch.randn(batch_size, in_size, device="cuda") for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size, device="cuda") for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()

##########################################################
# Default Precision
# -----------------
# Without ``torch.cuda.amp``, the following simple network executes all ops in default precision (``torch.float32``):

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        output = net(input)
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Default precision:")

##########################################################
# Adding autocast
# ---------------
# Instances of `torch.cuda.amp.autocast <https://pytorch.org/docs/stable/amp.html#autocasting>`_
# serve as context managers that allow regions of your script to run in mixed precision.
#
# In these regions, CUDA ops run in a dtype chosen by autocast
# to improve performance while maintaining accuracy.
# See the `Autocast Op Reference <https://pytorch.org/docs/stable/amp.html#autocast-op-reference>`_
# for details on what precision autocast chooses for each op, and under what circumstances.

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        # Runs the forward pass under autocast.
        with torch.cuda.amp.autocast():
            output = net(input)
            # output is float16 because linear layers autocast to float16.
            assert output.dtype is torch.float16

            loss = loss_fn(output, target)
            # loss is float32 because mse_loss layers autocast to float32.
            assert loss.dtype is torch.float32

        # Exits autocast before backward().
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        loss.backward()
        opt.step()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

##########################################################
# Adding GradScaler
# -----------------
# `Gradient scaling <https://pytorch.org/docs/stable/amp.html#gradient-scaling>`_
# helps prevent gradients with small magnitudes from flushing to zero
# ("underflowing") when training with mixed precision.
#
# `torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_
# performs the steps of gradient scaling conveniently.

# Constructs scaler once, at the beginning of the convergence run, using default args.
# If your network fails to converge with default GradScaler args, please file an issue.
# The same GradScaler instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh GradScaler instance. GradScaler instances are lightweight.
scaler = torch.cuda.amp.GradScaler()

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.cuda.amp.autocast():
            output = net(input)
            loss = loss_fn(output, target)

        # Scales loss. Calls backward() on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(opt)

        # Updates the scale for next iteration.
        scaler.update()

        opt.zero_grad() # set_to_none=True here can modestly improve performance

##########################################################
# All together: "Automatic Mixed Precision"
# ------------------------------------------
# (The following also demonstrates ``enabled``, an optional convenience argument to ``autocast`` and ``GradScaler``.
# If False, ``autocast`` and ``GradScaler``\ 's calls become no-ops.
# This allows switching between default precision and mixed precision without if/else statements.)

use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

##########################################################
# Inspecting/modifying gradients (e.g., clipping)
# --------------------------------------------------------
# All gradients produced by ``scaler.scale(loss).backward()`` are scaled. If you wish to modify or inspect
# the parameters' ``.grad`` attributes between ``backward()`` and ``scaler.step(optimizer)``, you should
# unscale them first using `scaler.unscale_(optimizer) <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.unscale_>`_.

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.cuda.amp.autocast():
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(opt)

        # Since the gradients of optimizer's assigned params are now unscaled, clips as usual.
        # You may use the same value for max_norm here as you would without gradient scaling.
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)

        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

##########################################################
# Saving/Resuming
# ----------------
# To save/resume Amp-enabled runs with bitwise accuracy, use
# `scaler.state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.state_dict>`_ and
# `scaler.load_state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.load_state_dict>`_.
#
# When saving, save the scaler state dict alongside the usual model and optimizer state dicts.
# Do this either at the beginning of an iteration before any forward passes, or at the end of
# an iteration after ``scaler.update()``.

checkpoint = {"model": net.state_dict(),
              "optimizer": opt.state_dict(),
              "scaler": scaler.state_dict()}
# Write checkpoint as desired, e.g.,
# torch.save(checkpoint, "filename")

##########################################################
# When resuming, load the scaler state dict alongside the model and optimizer state dicts.

# Read checkpoint as desired, e.g.,
# dev = torch.cuda.current_device()
# checkpoint = torch.load("filename",
#                         map_location = lambda storage, loc: storage.cuda(dev))
net.load_state_dict(checkpoint["model"])
opt.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])

##########################################################
# If a checkpoint was created from a run *without* Amp, and you want to resume training *with* Amp,
# load model and optimizer states from the checkpoint as usual. The checkpoint won't contain a saved scaler state, so
# use a fresh instance of ``GradScaler``.
#
# If a checkpoint was created from a run *with* Amp and you want to resume training *without* Amp,
# load model and optimizer states from the checkpoint as usual, and ignore the saved scaler state.

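##########################################################
# For example (a minimal sketch; ``filename_no_amp`` is a hypothetical checkpoint
# saved by a run that did not use Amp):

# checkpoint = torch.load("filename_no_amp")
# net.load_state_dict(checkpoint["model"])
# opt.load_state_dict(checkpoint["optimizer"])
# scaler = torch.cuda.amp.GradScaler()  # fresh scaler; there is no saved scaler state to restore
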
##########################################################
# Inference/Evaluation
# --------------------
# ``autocast`` may be used by itself to wrap inference or evaluation forward passes. ``GradScaler`` is not necessary.

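##########################################################
# For example, a minimal evaluation sketch reusing the model and data defined above
# (gradients are unnecessary, so the pass also runs under ``torch.no_grad()``):

net.eval()
with torch.no_grad():
    with torch.cuda.amp.autocast():
        eval_output = net(data[0])
        # Linear layers autocast to float16, so the evaluation output is float16.
        assert eval_output.dtype is torch.float16
net.train()
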
##########################################################
# .. _advanced-topics:
#
# Advanced topics
# ---------------
# See the `Automatic Mixed Precision Examples <https://pytorch.org/docs/stable/notes/amp_examples.html>`_ for advanced use cases including:
#
# * Gradient accumulation (a minimal sketch appears just below this section)
# * Gradient penalty/double backward
# * Networks with multiple models, optimizers, or losses
# * Multiple GPUs (``torch.nn.DataParallel`` or ``torch.nn.parallel.DistributedDataParallel``)
# * Custom autograd functions (subclasses of ``torch.autograd.Function``)
#
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh GradScaler instance. GradScaler instances are lightweight.
#
# If you're registering a custom C++ op with the dispatcher, see the
# `autocast section <https://pytorch.org/tutorials/advanced/dispatcher.html#autocast>`_
# of the dispatcher tutorial.

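##########################################################
# For instance, a minimal gradient accumulation sketch in the spirit of the Examples page
# (``iters_to_accumulate`` is an illustrative choice; the loop reuses the data, model,
# optimizer, and scaler defined above):

iters_to_accumulate = 4

for epoch in range(0): # 0 epochs, this section is for illustration only
    for i, (input, target) in enumerate(zip(data, targets)):
        with torch.cuda.amp.autocast():
            output = net(input)
            # Average the loss over the accumulation window so gradients match a larger batch.
            loss = loss_fn(output, target) / iters_to_accumulate
        scaler.scale(loss).backward()
        if (i + 1) % iters_to_accumulate == 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad() # set_to_none=True here can modestly improve performance
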
##########################################################
# .. _troubleshooting:
#
# Troubleshooting
# ---------------
# Speedup with Amp is minor
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. Your network may fail to saturate the GPU(s) with work, and is therefore CPU bound.
#    Amp's effect on GPU performance won't matter in that case.
#
#    * A rough rule of thumb to saturate the GPU is to increase batch and/or network size(s)
#      as much as you can without running OOM.
#    * Try to avoid excessive CPU-GPU synchronization (``.item()`` calls, or printing values from CUDA tensors).
#    * Try to avoid sequences of many small CUDA ops (coalesce these into a few large CUDA ops if you can).
# 2. Your network may be GPU compute bound (lots of matmuls/convolutions) but your GPU does not have Tensor Cores.
#    In this case a reduced speedup is expected.
# 3. Matmul dimensions are not Tensor Core-friendly. Make sure matmuls' participating sizes are multiples of 8.
#    (For NLP models with encoders/decoders, this can be subtle. Also, convolutions used to have similar size constraints
#    for Tensor Core use, but for CuDNN versions 7.3 and later, no such constraints exist. See
#    `here <https://github.com/NVIDIA/apex/issues/221#issuecomment-478084841>`_ for guidance.)
#
# Loss is inf/NaN
# ~~~~~~~~~~~~~~~
# First, check if your network fits an :ref:`advanced use case<advanced-topics>`.
# See also `Prefer binary_cross_entropy_with_logits over binary_cross_entropy <https://pytorch.org/docs/stable/amp.html#prefer-binary-cross-entropy-with-logits-over-binary-cross-entropy>`_.
#
# If you're confident your Amp usage is correct, you may need to file an issue, but before doing so, it's helpful to gather the following information:
#
# 1. Disable ``autocast`` or ``GradScaler`` individually (by passing ``enabled=False`` to their constructor) and see if infs/NaNs persist.
# 2. If you suspect part of your network (e.g., a complicated loss function) overflows, run that forward region in ``float32``
#    and see if infs/NaNs persist.
#    `The autocast docstring <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_'s last code snippet
#    shows forcing a subregion to run in ``float32`` (by locally disabling autocast and casting the subregion's inputs).
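#
# For instance, a minimal sketch of forcing a suspect region (here, the loss computation)
# to ``float32``, reusing the objects defined above:

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.cuda.amp.autocast():
            output = net(input)
            # Locally disable autocast and cast this region's inputs to float32.
            with torch.cuda.amp.autocast(enabled=False):
                loss = loss_fn(output.float(), target.float())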
#
# Type mismatch error (may manifest as CUDNN_STATUS_BAD_PARAM)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Autocast tries to cover all ops that benefit from or require casting.
# `Ops that receive explicit coverage <https://pytorch.org/docs/stable/amp.html#autocast-op-reference>`_
# are chosen based on numerical properties, but also on experience.
# If you see a type mismatch error in an autocast-enabled forward region or a backward pass following that region,
# it's possible autocast missed an op.
#
# Please file an issue with the error backtrace. Set ``export TORCH_SHOW_CPP_STACKTRACES=1`` before running your script
# to get fine-grained information on which backend op is failing.

recipes_source/recipes_index.rst

Lines changed: 10 additions & 0 deletions
@@ -167,6 +167,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/android_native_app_with_custom_op.html
    :tags: Mobile
 
+.. Automatic Mixed Precision
+
+.. customcarditem::
+   :header: Automatic Mixed Precision
+   :card_description: Use torch.cuda.amp to reduce runtime and save memory on NVIDIA GPUs.
+   :image: ../_static/img/thumbnails/cropped/amp.png
+   :link: ../recipes/recipes/amp_recipe.html
+   :tags: Model-Optimization
+
 .. End of tutorial card section
 
 .. raw:: html
@@ -199,6 +208,7 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/recipes/Captum_Recipe
    /recipes/recipes/tensorboard_with_pytorch
    /recipes/recipes/dynamic_quantization
+   /recipes/recipes/amp_recipe
    /recipes/torchscript_inference
    /recipes/deployment_with_flask
    /recipes/distributed_rpc_profiling
