diff --git a/recipes_source/recipes/amp_recipe.py b/recipes_source/recipes/amp_recipe.py index 86b278d351e..eb92b90d0cc 100644 --- a/recipes_source/recipes/amp_recipe.py +++ b/recipes_source/recipes/amp_recipe.py @@ -11,7 +11,7 @@ range of ``float32``. Mixed precision tries to match each op to its appropriate datatype, which can reduce your network's runtime and memory footprint. -Ordinarily, "automatic mixed precision training" uses `torch.autocast <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_ and +Ordinarily, "automatic mixed precision training" uses `torch.autocast <https://pytorch.org/docs/stable/amp.html#torch.autocast>`_ and `torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_ together. This recipe measures the performance of a simple network in default precision, @@ -19,7 +19,7 @@ mixed precision with improved performance. You may download and run this recipe as a standalone Python script. -The only requirements are Pytorch 1.6+ and a CUDA-capable GPU. +The only requirements are PyTorch 1.6 or later and a CUDA-capable GPU. Mixed precision primarily benefits Tensor Core-enabled architectures (Volta, Turing, Ampere). This recipe should show significant (2-3X) speedup on those architectures. @@ -105,7 +105,7 @@ def make_model(in_size, out_size, num_layers): ########################################################## # Adding autocast # --------------- -# Instances of `torch.cuda.amp.autocast <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_ +# Instances of `torch.autocast <https://pytorch.org/docs/stable/amp.html#torch.autocast>`_ # serve as context managers that allow regions of your script to run in mixed precision. # # In these regions, CUDA ops run in a dtype chosen by autocast @@ -310,7 +310,7 @@ def make_model(in_size, out_size, num_layers): # 1. Disable ``autocast`` or ``GradScaler`` individually (by passing ``enabled=False`` to their constructor) and see if infs/NaNs persist. # 2. If you suspect part of your network (e.g., a complicated loss function) overflows , run that forward region in ``float32`` # and see if infs/NaNs persist. 
-# `The autocast docstring <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_'s last code snippet +# `The autocast docstring <https://pytorch.org/docs/stable/amp.html#torch.autocast>`_'s last code snippet # shows forcing a subregion to run in ``float32`` (by locally disabling autocast and casting the subregion's inputs). # # Type mismatch error (may manifest as CUDNN_STATUS_BAD_PARAM)