Torch Function modes x torch.compile tutorial #3320

Merged
merged 18 commits into from
Apr 21, 2025
9 changes: 9 additions & 0 deletions recipes_source/recipes_index.rst
@@ -317,6 +317,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
:link: ../recipes/amx.html
:tags: Model-Optimization

.. (beta) Utilizing Torch Function modes with torch.compile

.. customcarditem::
:header: (beta) Utilizing Torch Function modes with torch.compile
:card_description: Override torch operators with Torch Function modes and torch.compile
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: ../recipes/torch_compile_torch_function_modes.html
:tags: Model-Optimization

.. (beta) Compiling the Optimizer with torch.compile

.. customcarditem::
77 changes: 77 additions & 0 deletions recipes_source/torch_compile_torch_function_modes.py
@@ -0,0 +1,77 @@
"""
(beta) Utilizing Torch Function modes with torch.compile
============================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""

#########################################################
# This tutorial covers how to use a key torch extensibility point,
# torch function modes, in tandem with torch.compile to override
# the behavior of torch ops at trace time, with no runtime overhead.
#
# .. note::
#
# This tutorial requires PyTorch 2.7.0 or later.
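Before the compiled example, it may help to see the extensibility point on its own in eager mode. The following is a sketch, not part of the original recipe: a hypothetical ``LoggingMode`` (the name is ours) that intercepts every torch API call and records the op's name. In eager mode the ``__torch_function__`` hook runs as Python code on every single op; that per-op overhead is exactly what the ``torch.compile`` integration below removes by applying the override once at trace time.

```python
import torch
from torch.overrides import TorchFunctionMode

# Hypothetical illustration (not from the tutorial): a mode that logs every
# intercepted torch API call and then runs it unchanged.
class LoggingMode(TorchFunctionMode):
    def __init__(self):
        self.calls = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # record which op ran, then invoke it unchanged
        self.calls.append(getattr(func, "__name__", str(func)))
        return func(*args, **kwargs)

mode = LoggingMode()
with mode:
    t = torch.ones(2) + torch.ones(2)
# ``mode.calls`` now records the intercepted ops, including the ``add``
# dispatched by the ``+`` operator
```

Note that the factory calls (``torch.ones``) are intercepted too, since torch function modes see the full torch API surface, not just tensor methods.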


#####################################################################
# Rewriting a torch op (``torch.add`` -> ``torch.mul``)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# For this example, we'll use torch function modes to rewrite occurrences
# of addition with multiplication instead. This type of override can be common
# if a certain backend has a custom implementation that should be dispatched
# for a given op.
import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

from torch.overrides import BaseTorchFunctionMode

# Define our mode; note that ``BaseTorchFunctionMode``
# implements the actual invocation of ``func(..)``
class AddToMultiplyMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func == torch.Tensor.add:
            func = torch.mul

        return super().__torch_function__(func, types, args, kwargs)

@torch.compile()
def test_fn(x, y):
    return x + y * x  # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)

with AddToMultiplyMode():
    z = test_fn(x, y)

assert torch.allclose(z, x * y * x)

# The mode can also be used within the compiled region itself, like so:

@torch.compile()
def test_fn(x, y):
    with AddToMultiplyMode():
        return x + y * x  # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)

assert torch.allclose(z, x * y * x)
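One subtlety worth noting: the check above compares against ``torch.Tensor.add``, which is the ``func`` object passed for the method and operator spellings (``x + y``, ``x.add(y)``). The functional spelling ``torch.add(x, y)`` dispatches with ``func == torch.add`` instead. The sketch below (our illustration, not part of the original recipe; the class name is ours) checks for both objects so either spelling is rewritten. It runs in eager mode here, but the same mode works under ``torch.compile``.

```python
import torch
from torch.overrides import BaseTorchFunctionMode

class AddToMultiplyModeBothForms(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        # rewrite both torch.Tensor.add (method/operator form)
        # and torch.add (functional form)
        if func in (torch.Tensor.add, torch.add):
            func = torch.mul
        return super().__torch_function__(func, types, args, kwargs)

a = torch.full((2, 2), 3.0)
b = torch.full((2, 2), 4.0)
with AddToMultiplyModeBothForms():
    assert torch.allclose(a + b, a * b)            # operator form rewritten
    assert torch.allclose(torch.add(a, b), a * b)  # functional form rewritten
```

Outside the ``with`` block, both spellings revert to ordinary addition, since the override only applies while the mode is active.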

######################################################################
# Conclusion
# ~~~~~~~~~~
# In this tutorial we demonstrated how to override the behavior of ``torch.*`` operators
# using torch function modes from within ``torch.compile``. This enables users to utilize
# the extensibility benefits of torch function modes without the runtime overhead
# of calling ``__torch_function__`` on every op invocation.
#
# * `Extending Torch API with Modes <https://pytorch.org/docs/stable/notes/extending.html#extending-all-torch-api-with-modes>`__ - Other examples and background on Torch Function modes.