From 801adf74219e5a86ca61514c2bf9bb17f69e4e60 Mon Sep 17 00:00:00 2001
From: Bin Bao
Date: Wed, 29 May 2024 14:46:17 -0400
Subject: [PATCH] Update the fusion section of
 https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html with
 torch.compile

---
 recipes_source/recipes/tuning_guide.py | 41 +++++++++++++-------------
 1 file changed, 21 insertions(+), 20 deletions(-)

diff --git a/recipes_source/recipes/tuning_guide.py b/recipes_source/recipes/tuning_guide.py
index fcc07595506..9f2c70a8921 100644
--- a/recipes_source/recipes/tuning_guide.py
+++ b/recipes_source/recipes/tuning_guide.py
@@ -94,35 +94,36 @@
 # ``optimizer.zero_grad(set_to_none=True)``.
 
 ###############################################################################
-# Fuse pointwise operations
+# Fuse operations
 # ~~~~~~~~~~~~~~~~~~~~~~~~~
-# Pointwise operations (elementwise addition, multiplication, math functions -
-# ``sin()``, ``cos()``, ``sigmoid()`` etc.) can be fused into a single kernel
-# to amortize memory access time and kernel launch time.
-#
-# `PyTorch JIT <https://pytorch.org/docs/stable/jit.html>`_ can fuse kernels
-# automatically, although there could be additional fusion opportunities not yet
-# implemented in the compiler, and not all device types are supported equally.
-#
-# Pointwise operations are memory-bound, for each operation PyTorch launches a
-# separate kernel. Each kernel loads data from the memory, performs computation
-# (this step is usually inexpensive) and stores results back into the memory.
-#
-# Fused operator launches only one kernel for multiple fused pointwise ops and
-# loads/stores data only once to the memory. This makes JIT very useful for
-# activation functions, optimizers, custom RNN cells etc.
+# Pointwise operations such as elementwise addition, multiplication, and math
+# functions like ``sin()``, ``cos()``, and ``sigmoid()`` can be combined into a
+# single kernel. This fusion helps reduce memory access and kernel launch times.
+# Pointwise operations are typically memory-bound; in eager mode, PyTorch launches
+# a separate kernel for each operation, which involves loading data from memory,
+# executing the operation (often not the most time-consuming step), and writing
+# the results back to memory.
+#
+# A fused operator launches only one kernel for multiple pointwise operations,
+# and loads and stores data just once. This efficiency is particularly
+# beneficial for activation functions, optimizers, and custom RNN cells.
+#
+# PyTorch 2 introduces a compile mode backed by TorchInductor, an underlying
+# compiler that automatically fuses kernels. TorchInductor goes beyond simple
+# elementwise operations, enabling advanced fusion of eligible pointwise and
+# reduction operations for optimized performance.
 #
 # In the simplest case fusion can be enabled by applying
-# `torch.jit.script <https://pytorch.org/docs/stable/generated/torch.jit.script.html>`_
+# `torch.compile <https://pytorch.org/docs/stable/generated/torch.compile.html>`_
 # decorator to the function definition, for example:
 
-@torch.jit.script
-def fused_gelu(x):
+@torch.compile
+def gelu(x):
     return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
 
 ###############################################################################
 # Refer to
-# `TorchScript documentation <https://pytorch.org/docs/stable/jit.html>`_
+# `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
 # for more advanced use cases.
 
 ###############################################################################
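
A quick way to exercise the updated example is to call the decorated function and compare it against
``torch.nn.functional.gelu``. This is a minimal sketch, separate from the patch itself; it assumes a
PyTorch 2.x install where ``torch.compile`` is available, and the tensor shape is an arbitrary choice
for illustration:

    import torch

    @torch.compile
    def gelu(x):
        return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

    x = torch.randn(1024, 1024)
    y = gelu(x)  # the first call triggers compilation; later calls reuse the compiled kernel
    # 1.41421 approximates sqrt(2), so the result should closely match the exact GELU
    torch.testing.assert_close(y, torch.nn.functional.gelu(x), rtol=1e-4, atol=1e-4)

The first invocation is slower because TorchInductor generates and compiles the fused kernel; subsequent
calls run the cached, fused code.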