From c307876147b5df626c1d4c5e910bfd7280d9b672 Mon Sep 17 00:00:00 2001 From: rzou Date: Fri, 10 Jan 2025 08:58:00 -0800 Subject: [PATCH 01/19] Update user-defined triton kernels tutorial with new torch.library.triton_op The new API is a more advanced complement to the existing APIs. --- ...ile_user_defined_triton_kernel_tutorial.py | 209 +++++++++++++++++- 1 file changed, 206 insertions(+), 3 deletions(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index 7d183af6fd1..98c041bb275 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -140,17 +140,220 @@ def add_fn(x, y): print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}") ###################################################################### -# Composibility and Limitations +# Composability +# ------------------------------------------------------------------- +# +# User-defined triton kernels do not automatically support all PyTorch +# subsystems, like in the following use cases: +# - Adding a CPU fallback +# - Adding a ``FlopCounter`` formula +# - Composing with Tensor Subclasses +# +# To compose with additional PyTorch subsystems, use ``torch.library.triton_op``. +# +# triton_op is a structured way of defining a custom operator that is backed by one +# or more triton kernels: like regular custom operators (``torch.library.custom_op``), +# you are able to specify the interactions with PyTorch subsystems via ``torch.library``. +# However, unlike ``torch.library.custom_op``, which creates opaque callables w.r.t. +# ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply optimizations. +# +# Here’s a chart of which API to use when integrating triton kernels with PyTorch. +# +# .. list-table:: +# :header-rows: 1 +# +# * - +# - triton kernel (no explicit torch.library wrapper) +# - torch.library.triton_op +# - torch.library.custom_op +# * - Supports inference +# - Yes +# - Yes +# - Yes +# * - Supports training +# - In the majority of cases +# - Yes +# - Yes +# * - Supports torch.compile +# - Yes +# - Yes +# - Yes +# * - Supports torch.compile(fullgraph=True) +# - In the majority of cases +# - In the majority of cases +# - In all cases +# * - Does torch.compile trace into the implementation? +# - Yes +# - Yes +# - No +# * - Supports AOTInductor +# - Yes +# - Yes +# - No +# * - Supports PyTorch Subsystems like FlopCounterMode, CPU Fallback, Tensor Subclasses +# - No +# - Yes +# - Yes + +###################################################################### +# Wrapping triton kernels with triton_op +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more triton kernels. +# Use ``torch.library.wrap_triton`` to wrap the calls to the triton kernel. 
+
+from torch.library import triton_op, wrap_triton
+
+@triton_op("mylib::mysin", mutates_args={})
+def mysin(x: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    n_elements = x.numel()
+    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+    return out
+
+@triton.jit
+def sin_kernel(
+    in_ptr0,
+    out_ptr,
+    n_elements,
+    BLOCK_SIZE: "tl.constexpr",
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(in_ptr0 + offsets, mask=mask)
+    output = tl.sin(x)
+    tl.store(out_ptr + offsets, output, mask=mask)
+
+def sin_triton(x):
+    out = torch.empty_like(x)
+    n_elements = x.numel()
+    sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+    return out
+
+######################################################################
+# You can invoke the ``triton_op`` in one of the following two ways.
+
+x = torch.randn(3, device="cuda")
+y = mysin(x)
+z = torch.ops.mylib.mysin.default(x)
+
+assert torch.allclose(y, x.sin())
+assert torch.allclose(z, x.sin())
+
+######################################################################
+# The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.
+
+y = torch.compile(mysin)(x)
+assert torch.allclose(y, x.sin())
+
+######################################################################
+# Adding training support
+# ^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
+# Prefer this to using ``torch.autograd.Function`` (which has various composability footguns
+# with ``torch.compile``).
+
+def backward(ctx, grad_output):
+    x, = ctx.saved_tensors
+    return grad_output * x.cos()
+
+def setup_context(ctx, inputs, output):
+    x, = inputs
+    ctx.save_for_backward(x)
+
+mysin.register_autograd(backward, setup_context=setup_context)
+
+######################################################################
+# Note that the backward must be a composition of PyTorch-understood operators.
+# If you want the backward to call triton kernels, then those must be wrapped in ``triton_op`` as well:
+
+@triton.jit
+def cos_kernel(
+    in_ptr0,
+    out_ptr,
+    n_elements,
+    BLOCK_SIZE: "tl.constexpr",
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(in_ptr0 + offsets, mask=mask)
+    output = tl.cos(x)
+    tl.store(out_ptr + offsets, output, mask=mask)
+
+@triton_op("mylib::mycos", mutates_args={})
+def mycos(x: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    n_elements = x.numel()
+    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+    return out
+
+def backward(ctx, grad_output):
+    x, = ctx.saved_tensors
+    return grad_output * mycos(x)
+
+def setup_context(ctx, inputs, output):
+    x, = inputs
+    ctx.save_for_backward(x)
+
+mysin.register_autograd(backward, setup_context=setup_context)
+
+######################################################################
+# Adding a CPU Fallback
+# ^^^^^^^^^^^^^^^^^^^^^
+# triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any other device) fallback for the ``triton_op``:
+
+@mysin.register_kernel("cpu")
+def _(x):
+    return torch.sin(x)
+
+x = torch.randn(3)
+y = mysin(x)
+assert torch.allclose(y, x.sin())
+
+######################################################################
+# The fallback must be composed of PyTorch operators.
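+
+######################################################################
+# Dispatch is per-device: with the fallback registered, CPU tensors take
+# the CPU kernel while CUDA tensors still run the Triton kernel. A quick
+# sketch to illustrate (it assumes a CUDA device is available, as in the
+# examples above):
+
+x_cpu = torch.randn(3)
+x_cuda = torch.randn(3, device="cuda")
+assert torch.allclose(mysin(x_cpu), x_cpu.sin())
+assert torch.allclose(mysin(x_cuda), x_cuda.sin())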
+ +###################################################################### +# Adding a FlopCounter Formula +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# To specify how many flops the triton kernel reports under PyTorch's flop counter, +# use ``register_flop_formula``. + +from torch.utils.flop_counter import FlopCounterMode, register_flop_formula + +@register_flop_formula(torch.ops.mylib.mysin) +def _(x_shape): + numel = 1 + for s in x_shape: + numel *= s + return numel + +x = torch.randn(3, device="cuda") + +# NB: FlopCounterMode requires tabulate. +# +# >>> with FlopCounterMode() as flop_counter: +# >>> y = mysin(x) + +###################################################################### +# Limitations # -------------------------------------------------------------------- # # As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile`` # includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor. # You can use these features together to build complex, high-performance models. # +# PyTorch 2.6 added ``torch.library.triton_op``, which adds support for +# user-defined Triton kernels in tensor subclasses and other advanced features. +# # However, there are certain limitations to be aware of: # -# * **Tensor Subclasses:** Currently, there is no support for -# tensor subclasses and other advanced features. # * **Triton Features:** While ``triton.heuristics`` can be used either standalone or # before ``triton.autotune``, it cannot be used after ``triton.autotune``. This # implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used From 5b03d047b898644a53c73cc49c3085999927fabf Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 10 Jan 2025 10:09:37 -0800 Subject: [PATCH 02/19] Update recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index 98c041bb275..65c23f41466 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -144,7 +144,8 @@ def add_fn(x, y): # ------------------------------------------------------------------- # # User-defined triton kernels do not automatically support all PyTorch -# subsystems, like in the following use cases: +# subsystems. 
This can be seen in the following use cases: + # - Adding a CPU fallback # - Adding a ``FlopCounter`` formula # - Composing with Tensor Subclasses From dd359f7d4ed80ad1651094f974e9ffb7976f1e91 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 13 Jan 2025 14:34:29 -0500 Subject: [PATCH 03/19] Update recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py Co-authored-by: Svetlana Karslioglu --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index 65c23f41466..a813dcc953a 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -143,7 +143,7 @@ def add_fn(x, y): # Composability # ------------------------------------------------------------------- # -# User-defined triton kernels do not automatically support all PyTorch +# User-defined Triton kernels do not automatically support all PyTorch # subsystems. This can be seen in the following use cases: # - Adding a CPU fallback From 20e86d096a64e6cbb29dc6522cc8302f7977dc8c Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 13 Jan 2025 14:34:41 -0500 Subject: [PATCH 04/19] Update recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py Co-authored-by: Svetlana Karslioglu --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index a813dcc953a..aba3577e1aa 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -153,7 +153,7 @@ def add_fn(x, y): # To compose with additional PyTorch subsystems, use ``torch.library.triton_op``. # # triton_op is a structured way of defining a custom operator that is backed by one -# or more triton kernels: like regular custom operators (``torch.library.custom_op``), +# or more Triton kernels: like regular custom operators (``torch.library.custom_op``), # you are able to specify the interactions with PyTorch subsystems via ``torch.library``. # However, unlike ``torch.library.custom_op``, which creates opaque callables w.r.t. # ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply optimizations. 
From 72d60075b0efa6f99510d17da293bc76e5318b5b Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 13 Jan 2025 14:34:49 -0500 Subject: [PATCH 05/19] Update recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py Co-authored-by: Svetlana Karslioglu --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index aba3577e1aa..f8bd54b666d 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -155,7 +155,7 @@ def add_fn(x, y): # triton_op is a structured way of defining a custom operator that is backed by one # or more Triton kernels: like regular custom operators (``torch.library.custom_op``), # you are able to specify the interactions with PyTorch subsystems via ``torch.library``. -# However, unlike ``torch.library.custom_op``, which creates opaque callables w.r.t. +# However, unlike ``torch.library.custom_op``, which creates opaque callables with respect to # ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply optimizations. # # Here’s a chart of which API to use when integrating triton kernels with PyTorch. From 58bc83d613d2fbca7f01fc7f260c95583ba6854e Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 13 Jan 2025 14:35:00 -0500 Subject: [PATCH 06/19] Update recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py Co-authored-by: Svetlana Karslioglu --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index f8bd54b666d..be1d9c6f2d2 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -200,7 +200,7 @@ def add_fn(x, y): # Wrapping triton kernels with triton_op # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # -# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more triton kernels. +# Use ``torch.library.triton_op`` to wrap a function that may invoke one or more Triton kernels. # Use ``torch.library.wrap_triton`` to wrap the calls to the triton kernel. from torch.library import triton_op, wrap_triton From 665cd6e9e328584a3b594cdceb87b5e13ee88a37 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 13 Jan 2025 14:35:08 -0500 Subject: [PATCH 07/19] Update recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py Co-authored-by: Svetlana Karslioglu --- .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py index be1d9c6f2d2..503534a1292 100644 --- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py +++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py @@ -269,7 +269,7 @@ def setup_context(ctx, inputs, output): ###################################################################### # Note that the backward must be a composition of PyTorch-understood operators. 
-# If you want the backward to call triton kernels, then those must be wrapped in ``triton_op`` as well:
+# If you want the backward to call Triton kernels, then those must be wrapped in ``triton_op`` as well:
 
 @triton.jit
 def cos_kernel(

From 4f9eede4da81f7cf426c5639e610317f768ec0e2 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:35:25 -0500
Subject: [PATCH 08/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index 503534a1292..759a8af59d5 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -201,7 +201,7 @@ def add_fn(x, y):
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # Use ``torch.library.triton_op`` to wrap a function that may invoke one or more Triton kernels.
-# Use ``torch.library.wrap_triton`` to wrap the calls to the triton kernel.
+# Use ``torch.library.wrap_triton`` to wrap the calls to the Triton kernel.
 
 from torch.library import triton_op, wrap_triton
 

From e5aaa9a79dba0c9e36aca6a2f8e08216f2555b88 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:35:32 -0500
Subject: [PATCH 09/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index 759a8af59d5..f283c09e8e7 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -306,7 +306,7 @@ def setup_context(ctx, inputs, output):
 ######################################################################
 # Adding a CPU Fallback
 # ^^^^^^^^^^^^^^^^^^^^^
-# triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any other device) fallback for the ``triton_op``:
+# Triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any other device) fallback for the ``triton_op``:
 
 @mysin.register_kernel("cpu")
 def _(x):

From 9de60b4f2536611ef10d168b45da10d83c51e50d Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:35:42 -0500
Subject: [PATCH 10/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index f283c09e8e7..90232ad72e9 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -337,7 +337,10 @@ def _(x_shape):
 
 x = torch.randn(3, device="cuda")
 
-# NB: FlopCounterMode requires tabulate.
+#########################################################
+# ``FlopCounterMode`` requires `tabulate <https://pypi.org/project/tabulate/>`__.
+# Before running the code below, make sure you have ``tabulate`` installed or install it by
+# running ``pip install tabulate``.
 #
 # >>> with FlopCounterMode() as flop_counter:
 # >>>     y = mysin(x)

From a165a7ef2576007c37467c7cd8a03361cf7a79c7 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:35:54 -0500
Subject: [PATCH 11/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index 90232ad72e9..fddaed45829 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -158,7 +158,7 @@ def add_fn(x, y):
 # However, unlike ``torch.library.custom_op``, which creates opaque callables with respect to
 # ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply optimizations.
 #
-# Here’s a chart of which API to use when integrating triton kernels with PyTorch.
+# Here’s a chart of which API to use when integrating Triton kernels with PyTorch.
 #
 # .. list-table::
 #    :header-rows: 1

From 3620928e47b0de90c7c1b6f2ecbb1d49806483d2 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:36:06 -0500
Subject: [PATCH 12/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index fddaed45829..1bc56350762 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -175,7 +175,7 @@ def add_fn(x, y):
 #      - In the majority of cases
 #      - Yes
 #      - Yes
-#    * - Supports torch.compile
+#    * - Supports ``torch.compile``
 #      - Yes
 #      - Yes
 #      - Yes

From e8c676307919d8a09b1192081895c3fa9428cd4a Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:36:19 -0500
Subject: [PATCH 13/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index 1bc56350762..666b5630336 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -164,7 +164,7 @@ def add_fn(x, y):
 #    :header-rows: 1
 #
 #    * -
-#      - triton kernel (no explicit torch.library wrapper)
+#      - Triton kernel (no explicit ``torch.library`` wrapper)
 #      - torch.library.triton_op
 #      - torch.library.custom_op
 #    * - Supports inference

From cbe7f043c555e5632e8e62417f7012c150be58c3 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:37:02 -0500
Subject: [PATCH 14/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 
deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index 666b5630336..e33e0847653 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -165,7 +165,7 @@ def add_fn(x, y):
 #
 #    * -
 #      - Triton kernel (no explicit ``torch.library`` wrapper)
-#      - torch.library.triton_op
+#      - ``torch.library.triton_op``
 #      - torch.library.custom_op
 #    * - Supports inference
 #      - Yes

From 844ff02e1ba524980004f7a8a656d84eb59023a5 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:37:14 -0500
Subject: [PATCH 15/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index e33e0847653..87d08661241 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -166,7 +166,7 @@ def add_fn(x, y):
 #    * -
 #      - Triton kernel (no explicit ``torch.library`` wrapper)
 #      - ``torch.library.triton_op``
-#      - torch.library.custom_op
+#      - ``torch.library.custom_op``
 #    * - Supports inference
 #      - Yes
 #      - Yes

From 7b0fb129ffa52783cc91d1675bfa0982de56bce9 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:40:04 -0500
Subject: [PATCH 16/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index 87d08661241..c4a31949d58 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -152,7 +152,7 @@ def add_fn(x, y):
 #
 # To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
 #
-# triton_op is a structured way of defining a custom operator that is backed by one
+# ``triton_op`` is a structured way of defining a custom operator that is backed by one
 # or more Triton kernels: like regular custom operators (``torch.library.custom_op``),
 # you are able to specify the interactions with PyTorch subsystems via ``torch.library``.
 # However, unlike ``torch.library.custom_op``, which creates opaque callables with respect to

From db1c64ed0c2e02b8eccbc4a5263e78fe86dbf889 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:40:13 -0500
Subject: [PATCH 17/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index c4a31949d58..0e12a033831 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -179,7 +179,7 @@ def add_fn(x, y):
 #      - Yes
 #      - Yes
 #      - Yes
-#    * - Supports torch.compile(fullgraph=True)
+#    * - Supports ``torch.compile(fullgraph=True)``
 #      - In the majority of cases
 #      - In the majority of cases
 #      - In all cases

From a88cf38789dd3c5b065cbb86dfa5c471cfc15729 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Mon, 13 Jan 2025 14:40:21 -0500
Subject: [PATCH 18/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

Co-authored-by: Svetlana Karslioglu
---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index 0e12a033831..f31afa2b514 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -197,7 +197,7 @@ def add_fn(x, y):
 #      - Yes
 
 ######################################################################
-# Wrapping triton kernels with triton_op
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Wrapping Triton kernels with ``triton_op``
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # Use ``torch.library.triton_op`` to wrap a function that may invoke one or more Triton kernels.

From bc9c98eedca865940f64e4922b8e3c2e0bfe3da5 Mon Sep 17 00:00:00 2001
From: Svetlana Karslioglu
Date: Mon, 13 Jan 2025 13:50:32 -0800
Subject: [PATCH 19/19] Update
 recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

---
 .../torch_compile_user_defined_triton_kernel_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
index f31afa2b514..7f3e6fbf6f8 100644
--- a/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
+++ b/recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py
@@ -145,7 +145,7 @@ def add_fn(x, y):
 #
 # User-defined Triton kernels do not automatically support all PyTorch
 # subsystems. This can be seen in the following use cases:
-
+#
 # - Adding a CPU fallback
 # - Adding a ``FlopCounter`` formula
 # - Composing with Tensor Subclasses