Update user-defined triton kernels tutorial with new torch.library.triton_op #3227

Merged
21 commits merged on Jan 24, 2025

Commits
c307876
Update user-defined triton kernels tutorial with new torch.library.tr…
zou3519 Jan 10, 2025
5b03d04
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
svekars Jan 10, 2025
dd359f7
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
20e86d0
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
72d6007
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
58bc83d
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
665cd6e
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
4f9eede
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
e5aaa9a
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
9de60b4
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
a165a7e
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
3620928
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
e8c6763
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
cbe7f04
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
844ff02
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
7b0fb12
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
db1c64e
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
a88cf38
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
zou3519 Jan 13, 2025
bc9c98e
Update recipes_source/torch_compile_user_defined_triton_kernel_tutori…
svekars Jan 13, 2025
9291e85
Merge branch 'main' into update_triton
svekars Jan 24, 2025
2ee6553
Merge branch 'main' into update_triton
svekars Jan 24, 2025
213 changes: 210 additions & 3 deletions recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -140,17 +140,224 @@ def add_fn(x, y):
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability
# -------------------------------------------------------------------
#
# User-defined Triton kernels do not automatically support all PyTorch
# subsystems. This can be seen in the following use cases:
#
# - Adding a CPU fallback
# - Adding a ``FlopCounter`` formula
# - Composing with Tensor Subclasses
#
# To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
#
# ``triton_op`` is a structured way of defining a custom operator that is backed by one
# or more Triton kernels. Like regular custom operators (``torch.library.custom_op``),
# it lets you specify the interactions with PyTorch subsystems via ``torch.library``.
# However, unlike ``torch.library.custom_op``, which creates opaque callables with respect to
# ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply optimizations.
#
# Here’s a table of which API to use when integrating Triton kernels with PyTorch.
#
# .. list-table::
# :header-rows: 1
#
# * -
# - Triton kernel (no explicit ``torch.library`` wrapper)
# - ``torch.library.triton_op``
# - ``torch.library.custom_op``
# * - Supports inference
# - Yes
# - Yes
# - Yes
# * - Supports training
# - In the majority of cases
# - Yes
# - Yes
# * - Supports ``torch.compile``
# - Yes
# - Yes
# - Yes
# * - Supports ``torch.compile(fullgraph=True)``
# - In the majority of cases
# - In the majority of cases
# - In all cases
# * - Does ``torch.compile`` trace into the implementation?
# - Yes
# - Yes
# - No
# * - Supports AOTInductor
# - Yes
# - Yes
# - No
# * - Supports PyTorch subsystems, such as ``FlopCounterMode``, CPU fallback, and tensor subclasses
# - No
# - Yes
# - Yes

######################################################################
# Wrapping Triton kernels with ``triton_op``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more Triton kernels.
# Use ``torch.library.wrap_triton`` to wrap the calls to the Triton kernel.

from torch.library import triton_op, wrap_triton

@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

@triton.jit
def sin_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.sin(x)
    tl.store(out_ptr + offsets, output, mask=mask)

# For comparison, ``sin_triton`` below launches the same kernel directly,
# without the ``triton_op`` wrapper:
def sin_triton(x):
    out = torch.empty_like(x)
    n_elements = x.numel()
    sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out
Comment on lines +230 to +234
Contributor

I guess this is the user-defined triton kernel w/ no @triton_op ? imo it would make sense to remove it (or add some context around it to indicate that this is the alternative where you don't use @triton_op)

######################################################################
# You can invoke the ``triton_op`` in one of the following two ways.

x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)

assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())

######################################################################
# The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.

y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())
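
######################################################################
# Because ``torch.compile`` traces into the ``triton_op``, it usually also
# works with ``fullgraph=True`` (see the table above). This check is an added
# illustration, not part of the original recipe:

y = torch.compile(mysin, fullgraph=True)(x)
assert torch.allclose(y, x.sin())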

######################################################################
# Adding training support
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
# Prefer this to using ``torch.autograd.Function`` (which has various composability footguns
# with ``torch.compile``).

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * x.cos()

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)
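
######################################################################
# As a quick sanity check (an added illustration, assuming the same CUDA setup
# as the rest of this recipe), the gradient of ``sin`` is ``cos``:

x = torch.randn(3, device="cuda", requires_grad=True)
y = mysin(x)
y.sum().backward()
assert torch.allclose(x.grad, x.detach().cos())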

######################################################################
# Note that the backward must be a composition of PyTorch-understood operators.
# If you want the backward to call Triton kernels, then those must be wrapped in ``triton_op`` as well:

@triton.jit
def cos_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.cos(x)
    tl.store(out_ptr + offsets, output, mask=mask)

@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * mycos(x)

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)
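
######################################################################
# The backward formula now goes through the Triton-backed ``mycos``. As an
# added check (again assuming CUDA), gradients still match ``cos``:

x = torch.randn(3, device="cuda", requires_grad=True)
torch.compile(mysin)(x).sum().backward()
assert torch.allclose(x.grad, x.detach().cos())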

######################################################################
# Adding a CPU Fallback
# ^^^^^^^^^^^^^^^^^^^^^
#
# Triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any
# other device) fallback for the ``triton_op``:

@mysin.register_kernel("cpu")
def _(x):
    return torch.sin(x)

x = torch.randn(3)
y = mysin(x)
assert torch.allclose(y, x.sin())

######################################################################
# The fallback must be composed of PyTorch operators.
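
######################################################################
# Since the fallback is built from PyTorch operators, it also composes with
# ``torch.compile`` on CPU inputs. A quick check, added for illustration,
# reusing the CPU tensor ``x`` from above:

y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())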

######################################################################
# Adding a FlopCounter Formula
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To specify how many flops the Triton kernel reports under PyTorch's flop counter,
# use ``register_flop_formula``.

from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
    numel = 1
    for s in x_shape:
        numel *= s
    return numel

x = torch.randn(3, device="cuda")

#########################################################
# ``FlopCounterMode`` requires `tabulate <https://pypi.org/project/tabulate/>`__.
# Before running the code below, make sure you have ``tabulate`` installed, or
# install it by running ``pip install tabulate``.
#
# >>> with FlopCounterMode() as flop_counter:
# >>> y = mysin(x)
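#
# The formula above counts one flop per element, so a 3-element input should
# report 3 flops. As an added check in the same style (``get_total_flops`` is
# the accessor on ``FlopCounterMode``):
#
# >>> with FlopCounterMode(display=False) as flop_counter:
# >>> y = mysin(x)
# >>> assert flop_counter.get_total_flops() == x.numel()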

######################################################################
# Limitations
Contributor

Having this title be limitations, and then having two paragraphs of not-limitations is a bit weird

# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
#
# PyTorch 2.6 added ``torch.library.triton_op``, which adds support for using
# user-defined Triton kernels with tensor subclasses and other advanced features.
#
# However, there are certain limitations to be aware of:
#
# * **Tensor Subclasses:** Raw Triton kernels (those not wrapped in
#   ``torch.library.triton_op``) do not compose with tensor subclasses and
#   other advanced subsystems; wrap them in ``triton_op`` to get that support.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
#   together, ``triton.heuristics`` must come first.