# -*- coding: utf-8 -*-
"""
Automatic Mixed Precision
*************************
**Author**: `Michael Carilli <https://github.com/mcarilli>`_

`torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ provides convenience methods for mixed precision,
where some operations use the ``torch.float32`` (``float``) datatype and other operations
use ``torch.float16`` (``half``). Some ops, like linear layers and convolutions,
are much faster in ``float16``. Other ops, like reductions, often require the dynamic
range of ``float32``. Mixed precision tries to match each op to its appropriate datatype,
which can reduce your network's runtime and memory footprint.

Ordinarily, "automatic mixed precision training" uses `torch.cuda.amp.autocast <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_ and
`torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_ together.

This recipe measures the performance of a simple network in default precision,
then walks through adding ``autocast`` and ``GradScaler`` to run the same network in
mixed precision with improved performance.

You may download and run this recipe as a standalone Python script.
The only requirements are PyTorch 1.6+ and a CUDA-capable GPU.

Mixed precision primarily benefits Tensor Core-enabled architectures (Volta, Turing, Ampere).
This recipe should show significant (2-3X) speedup on those architectures.
On earlier architectures (Kepler, Maxwell, Pascal), you may observe a modest speedup.
Run ``nvidia-smi`` to display your GPU's architecture.
"""

import torch, time, gc

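# Optional, programmatic complement to ``nvidia-smi`` (an addition to this recipe, not part
# of the timed runs below): Tensor Cores are available on GPUs with compute capability 7.0
# or higher (Volta, Turing, Ampere).
major, minor = torch.cuda.get_device_capability()
print("Compute capability {}.{} (Tensor Cores {}available)".format(
    major, minor, "" if major >= 7 else "not "))
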
# Timing utilities
start_time = None

def start_timer():
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.synchronize()
    start_time = time.time()

def end_timer_and_print(local_msg):
    torch.cuda.synchronize()
    end_time = time.time()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

##########################################################
# A simple network
# ----------------
# The following sequence of linear layers and ReLUs should show a speedup with mixed precision.

def make_model(in_size, out_size, num_layers):
    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(in_size, out_size))
    return torch.nn.Sequential(*layers).cuda()

##########################################################
# ``batch_size``, ``in_size``, ``out_size``, and ``num_layers`` are chosen to be large enough to saturate the GPU with work.
# Typically, mixed precision provides the greatest speedup when the GPU is saturated.
# Small networks may be CPU bound, in which case mixed precision won't improve performance.
# Sizes are also chosen such that linear layers' participating dimensions are multiples of 8,
# to permit Tensor Core usage on Tensor Core-capable GPUs (see :ref:`Troubleshooting<troubleshooting>` below).
#
# Exercise: Vary participating sizes and see how the mixed precision speedup changes.

batch_size = 512 # Try, for example, 128, 256, 513.
in_size = 4096
out_size = 4096
num_layers = 3
num_batches = 50
epochs = 3

# Creates data in default precision.
# The same data is used for both default and mixed precision trials below.
# You don't need to manually change inputs' dtype when enabling mixed precision.
data = [torch.randn(batch_size, in_size, device="cuda") for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size, device="cuda") for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()

##########################################################
# Default Precision
# -----------------
# Without ``torch.cuda.amp``, the following simple network executes all ops in default precision (``torch.float32``):

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        output = net(input)
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Default precision:")

##########################################################
# Adding autocast
# ---------------
# Instances of `torch.cuda.amp.autocast <https://pytorch.org/docs/stable/amp.html#autocasting>`_
# serve as context managers that allow regions of your script to run in mixed precision.
#
# In these regions, CUDA ops run in a dtype chosen by autocast
# to improve performance while maintaining accuracy.
# See the `Autocast Op Reference <https://pytorch.org/docs/stable/amp.html#autocast-op-reference>`_
# for details on what precision autocast chooses for each op, and under what circumstances.

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        # Runs the forward pass under autocast.
        with torch.cuda.amp.autocast():
            output = net(input)
            # output is float16 because linear layers autocast to float16.
            assert output.dtype is torch.float16

            loss = loss_fn(output, target)
            # loss is float32 because mse_loss layers autocast to float32.
            assert loss.dtype is torch.float32

        # Exits autocast before backward().
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        loss.backward()
        opt.step()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

##########################################################
# Adding GradScaler
# -----------------
# `Gradient scaling <https://pytorch.org/docs/stable/amp.html#gradient-scaling>`_
# helps prevent gradients with small magnitudes from flushing to zero
# ("underflowing") when training with mixed precision.
#
# `torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_
# performs the steps of gradient scaling conveniently.

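# A quick optional illustration (an addition to this recipe) of the underflow problem
# gradient scaling addresses: magnitudes below float16's smallest positive subnormal
# value (roughly 6e-8) flush to zero.
print(torch.tensor(1e-8, dtype=torch.float16))  # prints tensor(0., dtype=torch.float16)
print(torch.tensor(1e-8, dtype=torch.float32))  # the small value survives in float32
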
# Constructs scaler once, at the beginning of the convergence run, using default args.
# If your network fails to converge with default GradScaler args, please file an issue.
# The same GradScaler instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh GradScaler instance. GradScaler instances are lightweight.
scaler = torch.cuda.amp.GradScaler()

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.cuda.amp.autocast():
            output = net(input)
            loss = loss_fn(output, target)

        # Scales loss. Calls backward() on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(opt)

        # Updates the scale for next iteration.
        scaler.update()

        opt.zero_grad() # set_to_none=True here can modestly improve performance

##########################################################
# All together: "Automatic Mixed Precision"
# ------------------------------------------
# (The following also demonstrates ``enabled``, an optional convenience argument to ``autocast`` and ``GradScaler``.
# If False, ``autocast`` and ``GradScaler``\ 's calls become no-ops.
# This allows switching between default precision and mixed precision without if/else statements.)

use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

##########################################################
# Inspecting/modifying gradients (e.g., clipping)
# --------------------------------------------------------
# All gradients produced by ``scaler.scale(loss).backward()`` are scaled. If you wish to modify or inspect
# the parameters' ``.grad`` attributes between ``backward()`` and ``scaler.step(optimizer)``, you should
# unscale them first using `scaler.unscale_(optimizer) <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.unscale_>`_.

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.cuda.amp.autocast():
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(opt)

        # Since the gradients of optimizer's assigned params are now unscaled, clips as usual.
        # You may use the same value for max_norm here as you would without gradient scaling.
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)

        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

##########################################################
# Saving/Resuming
# ----------------
# To save/resume Amp-enabled runs with bitwise accuracy, use
# `scaler.state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.state_dict>`_ and
# `scaler.load_state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.load_state_dict>`_.
#
# When saving, save the scaler state dict alongside the usual model and optimizer state dicts.
# Do this either at the beginning of an iteration before any forward passes, or at the end of
# an iteration after ``scaler.update()``.

checkpoint = {"model": net.state_dict(),
              "optimizer": opt.state_dict(),
              "scaler": scaler.state_dict()}
# Write checkpoint as desired, e.g.,
# torch.save(checkpoint, "filename")

##########################################################
# When resuming, load the scaler state dict alongside the model and optimizer state dicts.

# Read checkpoint as desired, e.g.,
# dev = torch.cuda.current_device()
# checkpoint = torch.load("filename",
#                         map_location = lambda storage, loc: storage.cuda(dev))
net.load_state_dict(checkpoint["model"])
opt.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])

##########################################################
# If a checkpoint was created from a run *without* Amp, and you want to resume training *with* Amp,
# load model and optimizer states from the checkpoint as usual. The checkpoint won't contain a saved scaler state, so
# use a fresh instance of ``GradScaler``.
#
# If a checkpoint was created from a run *with* Amp and you want to resume training *without* Amp,
# load model and optimizer states from the checkpoint as usual, and ignore the saved scaler state.

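# For example, a minimal sketch of the first case (the file name here is hypothetical):
# checkpoint = torch.load("filename_saved_without_amp")
# net.load_state_dict(checkpoint["model"])
# opt.load_state_dict(checkpoint["optimizer"])
# scaler = torch.cuda.amp.GradScaler()  # fresh instance; there is no scaler state to load
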
##########################################################
# Inference/Evaluation
# --------------------
# ``autocast`` may be used by itself to wrap inference or evaluation forward passes. ``GradScaler`` is not necessary.

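# A minimal evaluation sketch (an addition to this recipe), reusing the data from above
# purely for illustration: wrap the forward pass in autocast, and in no_grad since no
# gradients are needed.
net.eval()
with torch.no_grad():
    with torch.cuda.amp.autocast():
        eval_output = net(data[0])
assert eval_output.dtype is torch.float16  # linear layers autocast to float16
net.train()
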
##########################################################
# .. _advanced-topics:
#
# Advanced topics
# ---------------
# See the `Automatic Mixed Precision Examples <https://pytorch.org/docs/stable/notes/amp_examples.html>`_ for advanced use cases including:
#
# * Gradient accumulation (a minimal sketch appears at the end of this section)
# * Gradient penalty/double backward
# * Networks with multiple models, optimizers, or losses
# * Multiple GPUs (``torch.nn.DataParallel`` or ``torch.nn.parallel.DistributedDataParallel``)
# * Custom autograd functions (subclasses of ``torch.autograd.Function``)
#
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh GradScaler instance. GradScaler instances are lightweight.
#
# If you're registering a custom C++ op with the dispatcher, see the
# `autocast section <https://pytorch.org/tutorials/advanced/dispatcher.html#autocast>`_
# of the dispatcher tutorial.

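# A minimal sketch of gradient accumulation with Amp, following the pattern in the linked
# examples. ``iters_to_accumulate`` is a hypothetical setting introduced only for this sketch.
iters_to_accumulate = 4

for epoch in range(0): # 0 epochs, this section is for illustration only
    for i, (input, target) in enumerate(zip(data, targets)):
        with torch.cuda.amp.autocast():
            output = net(input)
            # Divides the loss so accumulated gradients correspond to the effective batch.
            loss = loss_fn(output, target) / iters_to_accumulate
        scaler.scale(loss).backward()
        if (i + 1) % iters_to_accumulate == 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad() # set_to_none=True here can modestly improve performance
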
##########################################################
# .. _troubleshooting:
#
# Troubleshooting
# ---------------
# Speedup with Amp is minor
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. Your network may fail to saturate the GPU(s) with work, and is therefore CPU bound. Amp's effect on GPU performance
#    won't matter.
#
#    * A rough rule of thumb to saturate the GPU is to increase batch and/or network size(s)
#      as much as you can without running OOM.
#    * Try to avoid excessive CPU-GPU synchronization (``.item()`` calls, or printing values from CUDA tensors).
#    * Try to avoid sequences of many small CUDA ops (coalesce these into a few large CUDA ops if you can).
# 2. Your network may be GPU compute bound (lots of matmuls/convolutions) but your GPU does not have Tensor Cores.
#    In this case a reduced speedup is expected.
# 3. Matmul dimensions are not Tensor Core-friendly. Make sure matmuls' participating sizes are multiples of 8.
#    (For NLP models with encoders/decoders, this can be subtle. Also, convolutions used to have similar size constraints
#    for Tensor Core use, but for CuDNN versions 7.3 and later, no such constraints exist. See
#    `here <https://github.com/NVIDIA/apex/issues/221#issuecomment-478084841>`_ for guidance.)
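#
# A tiny illustration of point 3: a hypothetical helper (not used elsewhere in this recipe)
# that rounds a participating dimension up to the next multiple of 8.

def round_up_to_multiple_of_8(size):
    return ((size + 7) // 8) * 8

print(round_up_to_multiple_of_8(4097))  # prints 4104

##########################################################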
#
# Loss is inf/NaN
# ~~~~~~~~~~~~~~~
# First, check if your network fits an :ref:`advanced use case<advanced-topics>`.
# See also `Prefer binary_cross_entropy_with_logits over binary_cross_entropy <https://pytorch.org/docs/stable/amp.html#prefer-binary-cross-entropy-with-logits-over-binary-cross-entropy>`_.
#
# If you're confident your Amp usage is correct, you may need to file an issue, but before doing so, it's helpful to gather the following information:
#
# 1. Disable ``autocast`` or ``GradScaler`` individually (by passing ``enabled=False`` to their constructor) and see if infs/NaNs persist.
# 2. If you suspect part of your network (e.g., a complicated loss function) overflows, run that forward region in ``float32``
#    and see if infs/NaNs persist.
#    `The autocast docstring <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_'s last code snippet
#    shows forcing a subregion to run in ``float32`` (by locally disabling autocast and casting the subregion's inputs).
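#
# For example, here is a runnable sketch that forces a subregion to run in ``float32``
# by locally disabling autocast and casting its inputs. The loss computation stands in
# for whatever region you suspect.

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.cuda.amp.autocast():
            output = net(input)
            # Locally disables autocast and casts the suspect region's inputs to float32.
            with torch.cuda.amp.autocast(enabled=False):
                loss = loss_fn(output.float(), target.float())
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

##########################################################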
#
# Type mismatch error (may manifest as CUDNN_STATUS_BAD_PARAM)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Autocast tries to cover all ops that benefit from or require casting.
# `Ops that receive explicit coverage <https://pytorch.org/docs/stable/amp.html#autocast-op-reference>`_
# are chosen based on numerical properties, but also on experience.
# If you see a type mismatch error in an autocast-enabled forward region or a backward pass following that region,
# it's possible autocast missed an op.
#
# Please file an issue with the error backtrace. Running ``export TORCH_SHOW_CPP_STACKTRACES=1`` before running
# your script provides fine-grained information on which backend op is failing.