
Commit c27636a

zou3519 and svekars authored
Update user-defined triton kernels tutorial with new torch.library.triton_op (#3227)
The new API is a more advanced complement to the existing APIs.

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 5786e97 commit c27636a

File tree

1 file changed

+210 -3 lines changed

recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Lines changed: 210 additions & 3 deletions
@@ -140,17 +140,224 @@ def add_fn(x, y):
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability
# -------------------------------------------------------------------
#
# User-defined Triton kernels do not automatically support all PyTorch
# subsystems. This can be seen in the following use cases:
#
# - Adding a CPU fallback
# - Adding a ``FlopCounter`` formula
# - Composing with Tensor Subclasses
#
# To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
#
# ``triton_op`` is a structured way of defining a custom operator that is backed by one
# or more Triton kernels: like regular custom operators (``torch.library.custom_op``),
# you are able to specify the interactions with PyTorch subsystems via ``torch.library``.
# However, unlike ``torch.library.custom_op``, which creates opaque callables with respect to
# ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply optimizations.
#
# Here’s a chart of which API to use when integrating Triton kernels with PyTorch.
#
# .. list-table::
#    :header-rows: 1
#
#    * -
#      - Triton kernel (no explicit ``torch.library`` wrapper)
#      - ``torch.library.triton_op``
#      - ``torch.library.custom_op``
#    * - Supports inference
#      - Yes
#      - Yes
#      - Yes
#    * - Supports training
#      - In the majority of cases
#      - Yes
#      - Yes
#    * - Supports ``torch.compile``
#      - Yes
#      - Yes
#      - Yes
#    * - Supports ``torch.compile(fullgraph=True)``
#      - In the majority of cases
#      - In the majority of cases
#      - In all cases
#    * - Does ``torch.compile`` trace into the implementation?
#      - Yes
#      - Yes
#      - No
#    * - Supports AOTInductor
#      - Yes
#      - Yes
#      - No
#    * - Supports PyTorch subsystems like ``FlopCounterMode``, CPU fallback, and tensor subclasses
#      - No
#      - Yes
#      - Yes
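#
# For comparison, here is a sketch (not from the original tutorial) of what the
# opaque ``torch.library.custom_op`` variant of the operator in the next section
# would look like; the name ``mylib::mysin_custom`` is illustrative and
# ``sin_kernel`` refers to the Triton kernel defined below:
#
# .. code-block:: python
#
#     @torch.library.custom_op("mylib::mysin_custom", mutates_args=())
#     def mysin_custom(x: torch.Tensor) -> torch.Tensor:
#         out = torch.empty_like(x)
#         n_elements = x.numel()
#         # The raw Triton kernel is launched directly; torch.compile treats
#         # this whole function as an opaque call.
#         sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
#         return out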

######################################################################
# Wrapping Triton kernels with ``triton_op``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more Triton kernels.
# Use ``torch.library.wrap_triton`` to wrap the calls to the Triton kernel.

from torch.library import triton_op, wrap_triton

@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

@triton.jit
def sin_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.sin(x)
    tl.store(out_ptr + offsets, output, mask=mask)

def sin_triton(x):
    out = torch.empty_like(x)
    n_elements = x.numel()
    sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

######################################################################
# You can invoke the ``triton_op`` in one of the following two ways.

x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)

assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())

######################################################################
# The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.

y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())
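
######################################################################
# As an illustrative sketch (not part of the original tutorial), the same op can
# also be compiled ahead-of-time with AOTInductor. The wrapper module, the use of
# ``torch.export.export``, and ``torch._inductor.aoti_compile_and_package`` below
# assume PyTorch 2.6+; treat this as a sketch rather than the canonical
# AOTInductor workflow.
#
# .. code-block:: python
#
#     class MySinModule(torch.nn.Module):
#         def forward(self, x):
#             return mysin(x)
#
#     # Export the wrapper and compile it ahead-of-time into a .pt2 package.
#     ep = torch.export.export(MySinModule(), (x,))
#     package_path = torch._inductor.aoti_compile_and_package(ep)
#
#     # Load the compiled artifact and call it like a regular function.
#     compiled = torch._inductor.aoti_load_package(package_path)
#     assert torch.allclose(compiled(x), x.sin())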

######################################################################
# Adding training support
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
# Prefer this to using ``torch.autograd.Function`` (which has various composability footguns
# with ``torch.compile``).

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * x.cos()

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)

######################################################################
# Note that the backward must be a composition of PyTorch-understood operators.
# If you want the backward to call Triton kernels, then those must be wrapped in ``triton_op`` as well:

@triton.jit
def cos_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.cos(x)
    tl.store(out_ptr + offsets, output, mask=mask)

@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * mycos(x)

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)
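
######################################################################
# As a quick sanity check (added here for illustration; not part of the original
# tutorial), autograd now flows through ``mysin`` via the registered backward,
# and the gradient matches that of the built-in ``torch.sin``:

x = torch.randn(3, device="cuda", requires_grad=True)
y = mysin(x)
y.sum().backward()
assert torch.allclose(x.grad, x.detach().cos())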

######################################################################
# Adding a CPU Fallback
# ^^^^^^^^^^^^^^^^^^^^^
# Triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any other device) fallback for the ``triton_op``:

@mysin.register_kernel("cpu")
def _(x):
    return torch.sin(x)

x = torch.randn(3)
y = mysin(x)
assert torch.allclose(y, x.sin())

######################################################################
# The fallback must be composed of PyTorch operators.

######################################################################
# Adding a FlopCounter Formula
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To specify how many FLOPs the Triton kernel reports under PyTorch's flop counter,
# use ``register_flop_formula``.

from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
    numel = 1
    for s in x_shape:
        numel *= s
    return numel

x = torch.randn(3, device="cuda")

#########################################################
# ``FlopCounterMode`` requires `tabulate <https://pypi.org/project/tabulate/>`__.
# Before running the code below, make sure you have ``tabulate`` installed, or
# install it by running ``pip install tabulate``.
#
# >>> with FlopCounterMode() as flop_counter:
# ...     y = mysin(x)
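#
# As an illustrative follow-up (not part of the original tutorial), the counted
# FLOPs can then be read back from the mode object; this assumes
# ``FlopCounterMode.get_total_flops`` is available:
#
# >>> flop_counter.get_total_flops()  # reports the value from the registered formula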

######################################################################
# Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
#
# PyTorch 2.6 added ``torch.library.triton_op``, which adds support for composing
# user-defined Triton kernels with tensor subclasses and other advanced features.
#
# However, there are certain limitations to be aware of:
#
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
