
Commit c307876

Update user-defined triton kernels tutorial with new torch.library.triton_op
The new API is a more advanced complement to the existing APIs.
1 parent 82f449a commit c307876

File tree

1 file changed: +206 −3 lines changed

recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Lines changed: 206 additions & 3 deletions
@@ -140,17 +140,220 @@ def add_fn(x, y):
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

######################################################################
# Composability
# -------------------------------------------------------------------
#
# User-defined triton kernels do not automatically support all PyTorch
# subsystems, like in the following use cases:
#
# - Adding a CPU fallback
# - Adding a ``FlopCounter`` formula
# - Composing with Tensor Subclasses
#
# To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
#
# ``triton_op`` is a structured way of defining a custom operator that is backed by one
# or more triton kernels: like regular custom operators (``torch.library.custom_op``),
# you are able to specify the interactions with PyTorch subsystems via ``torch.library``.
# However, unlike ``torch.library.custom_op``, which creates opaque callables w.r.t.
# ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply optimizations.
#
# Here’s a chart of which API to use when integrating triton kernels with PyTorch.
#
# .. list-table::
#    :header-rows: 1
#
#    * -
#      - triton kernel (no explicit torch.library wrapper)
#      - torch.library.triton_op
#      - torch.library.custom_op
#    * - Supports inference
#      - Yes
#      - Yes
#      - Yes
#    * - Supports training
#      - In the majority of cases
#      - Yes
#      - Yes
#    * - Supports torch.compile
#      - Yes
#      - Yes
#      - Yes
#    * - Supports torch.compile(fullgraph=True)
#      - In the majority of cases
#      - In the majority of cases
#      - In all cases
#    * - Does torch.compile trace into the implementation?
#      - Yes
#      - Yes
#      - No
#    * - Supports AOTInductor
#      - Yes
#      - Yes
#      - No
#    * - Supports PyTorch Subsystems like FlopCounterMode, CPU Fallback, Tensor Subclasses
#      - No
#      - Yes
#      - Yes
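#
# For comparison, the same operator wrapped with ``torch.library.custom_op`` would look
# roughly like the sketch below (an illustration only, not part of this tutorial: the
# ``mylib::mysin_opaque`` name is made up, ``sin_kernel`` is the Triton kernel defined
# in the next section, and ``torch.compile`` would treat the result as an opaque call
# instead of tracing into it):
#
# >>> @torch.library.custom_op("mylib::mysin_opaque", mutates_args=())
# >>> def mysin_opaque(x: torch.Tensor) -> torch.Tensor:
# >>>     out = torch.empty_like(x)
# >>>     n_elements = x.numel()
# >>>     sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
# >>>     return out
# >>>
# >>> @mysin_opaque.register_fake
# >>> def _(x):
# >>>     return torch.empty_like(x)
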
######################################################################
# Wrapping triton kernels with triton_op
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more triton kernels.
# Use ``torch.library.wrap_triton`` to wrap the calls to the triton kernel.

from torch.library import triton_op, wrap_triton

@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

@triton.jit
def sin_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.sin(x)
    tl.store(out_ptr + offsets, output, mask=mask)

def sin_triton(x):
    out = torch.empty_like(x)
    n_elements = x.numel()
    sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

######################################################################
# You can invoke the ``triton_op`` in one of the following two ways.

x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)

assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())

######################################################################
# The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.

y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())
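
# AOTInductor consumes an exported program rather than a plain callable. A minimal
# sketch of ahead-of-time compiling ``mysin`` (the wrapper module, the output path,
# and the packaging call below are illustrative assumptions, not part of this tutorial):
#
# >>> class MySin(torch.nn.Module):
# >>>     def forward(self, x):
# >>>         return mysin(x)
# >>>
# >>> ep = torch.export.export(MySin(), (x,))
# >>> torch._inductor.aoti_compile_and_package(ep, package_path="mysin.pt2")
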
######################################################################
# Adding training support
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
# Prefer this to using ``torch.autograd.Function`` (which has various composability footguns
# with ``torch.compile``).

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * x.cos()

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)
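
######################################################################
# As a quick sanity check of the registered backward (a small sketch added for
# illustration; the gradient of ``sin`` is ``cos``, so ``x.grad`` should match
# ``x.cos()``):

x = torch.randn(3, device="cuda", requires_grad=True)
torch.compile(mysin)(x).sum().backward()
assert torch.allclose(x.grad, x.cos())
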
######################################################################
# Note that the backward must be a composition of PyTorch-understood operators.
# If you want the backward to call triton kernels, then those must be wrapped in ``triton_op`` as well:

@triton.jit
def cos_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.cos(x)
    tl.store(out_ptr + offsets, output, mask=mask)

@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    return grad_output * mycos(x)

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)

######################################################################
# Adding a CPU Fallback
# ^^^^^^^^^^^^^^^^^^^^^
# triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any other device) fallback for the ``triton_op``:

@mysin.register_kernel("cpu")
def _(x):
    return torch.sin(x)

x = torch.randn(3)
y = mysin(x)
assert torch.allclose(y, x.sin())

######################################################################
# The fallback must be composed of PyTorch operators.

######################################################################
# Adding a FlopCounter Formula
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To specify how many flops the triton kernel reports under PyTorch's flop counter,
# use ``register_flop_formula``.

from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
    numel = 1
    for s in x_shape:
        numel *= s
    return numel

x = torch.randn(3, device="cuda")

# NB: FlopCounterMode requires tabulate.
#
# >>> with FlopCounterMode() as flop_counter:
# >>>     y = mysin(x)
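# >>>
# >>> flop_counter.get_total_flops()
#
# (A hedged note continuing the snippet above: ``get_total_flops`` aggregates the
# counted flops, so with the formula registered here the total should equal ``x.numel()``.)
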
######################################################################
# Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
#
# PyTorch 2.6 added ``torch.library.triton_op``, which adds support for
# user-defined Triton kernels in tensor subclasses and other advanced features.
#
# However, there are certain limitations to be aware of:
#
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
