
Update the fusion section of tuning_guide.py #2889


Merged
merged 1 commit on May 30, 2024
41 changes: 21 additions & 20 deletions recipes_source/recipes/tuning_guide.py
@@ -94,35 +94,36 @@
# ``optimizer.zero_grad(set_to_none=True)``.

###############################################################################
# Fuse pointwise operations
# Fuse operations
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# Pointwise operations (elementwise addition, multiplication, math functions -
# ``sin()``, ``cos()``, ``sigmoid()`` etc.) can be fused into a single kernel
# to amortize memory access time and kernel launch time.
#
# `PyTorch JIT <https://pytorch.org/docs/stable/jit.html>`_ can fuse kernels
# automatically, although there could be additional fusion opportunities not yet
# implemented in the compiler, and not all device types are supported equally.
#
# Pointwise operations are memory-bound, for each operation PyTorch launches a
# separate kernel. Each kernel loads data from the memory, performs computation
# (this step is usually inexpensive) and stores results back into the memory.
#
# Fused operator launches only one kernel for multiple fused pointwise ops and
# loads/stores data only once to the memory. This makes JIT very useful for
# activation functions, optimizers, custom RNN cells etc.
# Pointwise operations such as elementwise addition, multiplication, and math
# functions like ``sin()``, ``cos()``, ``sigmoid()``, etc., can be combined into a
# single kernel. This fusion helps reduce memory access and kernel launch times.
# Typically, pointwise operations are memory-bound; in eager mode, PyTorch launches
# a separate kernel for each operation, which involves loading data from memory,
# executing the operation (often not the most time-consuming step), and writing
# the results back to memory.
#
# By using a fused operator, only one kernel is launched for multiple pointwise
# operations, and data is loaded and stored just once. This efficiency is
# particularly beneficial for activation functions, optimizers, custom RNN cells, etc.
#
# PyTorch 2 introduces a compile mode backed by TorchInductor, an underlying compiler
# that automatically fuses kernels. TorchInductor extends its capabilities beyond simple
# elementwise operations, enabling fusion of eligible pointwise and reduction
# operations for optimized performance (a small sketch of this follows the example below).
#
# In the simplest case, fusion can be enabled by applying the
# `torch.jit.script <https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script>`_
# `torch.compile <https://pytorch.org/docs/stable/generated/torch.compile.html>`_
# decorator to the function definition, for example:

@torch.jit.script
def fused_gelu(x):
@torch.compile
def gelu(x):
return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
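
###############################################################################
# As a minimal sketch of the pointwise-plus-reduction fusion mentioned above
# (assuming PyTorch 2.x, where ``torch.compile`` is available), a function that
# mixes elementwise math with a reduction can also be compiled; the function
# name below is purely illustrative:

@torch.compile
def softplus_mean(x):
    # log1p(exp(x)) is pointwise, mean() is a reduction; TorchInductor can fuse
    # eligible combinations like this into fewer kernels.
    return torch.log1p(torch.exp(x)).mean()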

###############################################################################
# Refer to
# `TorchScript documentation <https://pytorch.org/docs/stable/jit.html>`_
# `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
# for more advanced use cases.
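
###############################################################################
# A rough way to see the effect of fusion is to time the eager and compiled
# versions of the same function. The snippet below is a hypothetical sketch
# (``eager_gelu``, the input size, and the iteration count are arbitrary) and
# assumes a CUDA device; on CPU, drop the ``torch.cuda.synchronize()`` calls.

import time

import torch

def eager_gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

x = torch.randn(8192, 8192, device="cuda")

# Warm up: the first call to the compiled function triggers compilation.
gelu(x)

torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    eager_gelu(x)
torch.cuda.synchronize()
print(f"eager:    {time.perf_counter() - start:.4f} s")

torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    gelu(x)
torch.cuda.synchronize()
print(f"compiled: {time.perf_counter() - start:.4f} s")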

###############################################################################