Update user-defined triton kernels tutorial with new torch.library.triton_op #3227
Merged
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability
# -------------------------------------------------------------------
#
# User-defined Triton kernels do not automatically support all PyTorch
# subsystems. This can be seen in the following use cases:
#
# - Adding a CPU fallback
# - Adding a ``FlopCounter`` formula
# - Composing with Tensor Subclasses
#
# To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
#
# ``triton_op`` is a structured way of defining a custom operator that is backed
# by one or more Triton kernels: like regular custom operators
# (``torch.library.custom_op``), you can specify its interactions with PyTorch
# subsystems via ``torch.library``. However, unlike ``torch.library.custom_op``,
# which creates opaque callables with respect to ``torch.compile``,
# ``torch.compile`` traces into ``triton_op`` to apply optimizations.
#
# Here’s a chart of which API to use when integrating Triton kernels with PyTorch.
#
# .. list-table::
#    :header-rows: 1
#
#    * -
#      - Triton kernel (no explicit ``torch.library`` wrapper)
#      - ``torch.library.triton_op``
#      - ``torch.library.custom_op``
#    * - Supports inference
#      - Yes
#      - Yes
#      - Yes
#    * - Supports training
#      - In the majority of cases
#      - Yes
#      - Yes
#    * - Supports ``torch.compile``
#      - Yes
#      - Yes
#      - Yes
#    * - Supports ``torch.compile(fullgraph=True)``
#      - In the majority of cases
#      - In the majority of cases
#      - In all cases
#    * - Does ``torch.compile`` trace into the implementation?
#      - Yes
#      - Yes
#      - No
#    * - Supports AOTInductor
#      - Yes
#      - Yes
#      - No
#    * - Supports PyTorch subsystems like ``FlopCounterMode``, CPU fallback, tensor subclasses
#      - No
#      - Yes
#      - Yes


######################################################################
# Wrapping Triton kernels with ``triton_op``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more
# Triton kernels. Use ``torch.library.wrap_triton`` to wrap the calls to the
# Triton kernel.

from torch.library import triton_op, wrap_triton


@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out


@triton.jit
def sin_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.sin(x)
    tl.store(out_ptr + offsets, output, mask=mask)


# For comparison: the same kernel invoked directly, without the
# ``triton_op``/``wrap_triton`` wrappers.
def sin_triton(x):
    out = torch.empty_like(x)
    n_elements = x.numel()
    sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

######################################################################
# You can invoke the ``triton_op`` in one of the following two ways.

x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)

assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())


######################################################################
# The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.

y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())

######################################################################
# Adding training support
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
# Prefer this to using ``torch.autograd.Function``, which has various
# composability footguns with ``torch.compile``.

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * x.cos()


def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)


mysin.register_autograd(backward, setup_context=setup_context)

######################################################################
# Note that the backward must be a composition of PyTorch-understood operators.
# If you want the backward to call Triton kernels, then those must be wrapped in
# ``triton_op`` as well:

@triton.jit
def cos_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.cos(x)
    tl.store(out_ptr + offsets, output, mask=mask)


@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out


def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * mycos(x)


def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)


mysin.register_autograd(backward, setup_context=setup_context)

######################################################################
# Adding a CPU Fallback
# ^^^^^^^^^^^^^^^^^^^^^
#
# Triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any
# other device) fallback for the ``triton_op``:

@mysin.register_kernel("cpu")
def _(x):
    return torch.sin(x)


x = torch.randn(3)
y = mysin(x)
assert torch.allclose(y, x.sin())

######################################################################
# The fallback must be composed of PyTorch operators.


######################################################################
# Adding a FlopCounter Formula
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To specify how many flops the Triton kernel reports under PyTorch's flop
# counter, use ``register_flop_formula``.

from torch.utils.flop_counter import FlopCounterMode, register_flop_formula


@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
    numel = 1
    for s in x_shape:
        numel *= s
    return numel


x = torch.randn(3, device="cuda")


#########################################################
# ``FlopCounterMode`` requires `tabulate <https://pypi.org/project/tabulate/>`__.
# Before running the code below, make sure you have ``tabulate`` installed, or
# install it by running ``pip install tabulate``.
#
# >>> with FlopCounterMode() as flop_counter:
# ...     y = mysin(x)

######################################################################
# Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
#
# PyTorch 2.6 added ``torch.library.triton_op``, which adds support for
# user-defined Triton kernels in tensor subclasses and other advanced features.
#
# However, there are certain limitations to be aware of:
#
# * **Tensor Subclasses:** Without ``triton_op``, there is no support for
#   tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
#   together, ``triton.heuristics`` must come first.