@@ -140,17 +140,224 @@ def add_fn(x, y):
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
######################################################################
- # Composibility and Limitations
+ # Composability
+ # -------------------------------------------------------------------
+ #
+ # User-defined Triton kernels do not automatically support all PyTorch
+ # subsystems. This can be seen in the following use cases:
+ #
+ # - Adding a CPU fallback
+ # - Adding a ``FlopCounter`` formula
+ # - Composing with Tensor Subclasses
+ #
+ # To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
+ #
+ # ``triton_op`` is a structured way of defining a custom operator that is backed by one
+ # or more Triton kernels: like regular custom operators (``torch.library.custom_op``),
+ # you can specify how it interacts with PyTorch subsystems via ``torch.library``.
+ # However, unlike ``torch.library.custom_op``, which creates opaque callables with respect to
+ # ``torch.compile``, ``torch.compile`` traces into ``triton_op`` to apply optimizations.
+ #
+ # Here’s a chart of which API to use when integrating Triton kernels with PyTorch.
+ #
+ # .. list-table::
+ #    :header-rows: 1
+ #
+ #    * -
+ #      - Triton kernel (no explicit ``torch.library`` wrapper)
+ #      - ``torch.library.triton_op``
+ #      - ``torch.library.custom_op``
+ #    * - Supports inference
+ #      - Yes
+ #      - Yes
+ #      - Yes
+ #    * - Supports training
+ #      - In the majority of cases
+ #      - Yes
+ #      - Yes
+ #    * - Supports ``torch.compile``
+ #      - Yes
+ #      - Yes
+ #      - Yes
+ #    * - Supports ``torch.compile(fullgraph=True)``
+ #      - In the majority of cases
+ #      - In the majority of cases
+ #      - In all cases
+ #    * - Does ``torch.compile`` trace into the implementation?
+ #      - Yes
+ #      - Yes
+ #      - No
+ #    * - Supports AOTInductor
+ #      - Yes
+ #      - Yes
+ #      - No
+ #    * - Supports PyTorch subsystems, such as ``FlopCounterMode``, CPU fallback, and Tensor Subclasses
+ #      - No
+ #      - Yes
+ #      - Yes
+
+ ######################################################################
+ # Wrapping Triton kernels with ``triton_op``
+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ #
+ # Use ``torch.library.triton_op`` to wrap a function that may invoke one or more Triton kernels.
+ # Use ``torch.library.wrap_triton`` to wrap the calls to the Triton kernel.
+
+ from torch.library import triton_op, wrap_triton
+
+ @triton_op("mylib::mysin", mutates_args={})
+ def mysin(x: torch.Tensor) -> torch.Tensor:
+     out = torch.empty_like(x)
+     n_elements = x.numel()
+     wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+     return out
+
+ @triton.jit
+ def sin_kernel(
+     in_ptr0,
+     out_ptr,
+     n_elements,
+     BLOCK_SIZE: "tl.constexpr",
+ ):
+     pid = tl.program_id(axis=0)
+     block_start = pid * BLOCK_SIZE
+     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_elements
+     x = tl.load(in_ptr0 + offsets, mask=mask)
+     output = tl.sin(x)
+     tl.store(out_ptr + offsets, output, mask=mask)
+
+ def sin_triton(x):
+     out = torch.empty_like(x)
+     n_elements = x.numel()
+     sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+     return out
+
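+ ######################################################################
+ # For comparison with the table above, here is a sketch (not part of the original
+ # example) of how the same kernel could instead be wrapped with
+ # ``torch.library.custom_op``. The operator name ``mylib::mysin_opaque`` is
+ # illustrative. Unlike ``triton_op``, ``torch.compile`` treats this operator as an
+ # opaque callable and does not trace into its implementation.
+
+ from torch.library import custom_op
+
+ @custom_op("mylib::mysin_opaque", mutates_args=())
+ def mysin_opaque(x: torch.Tensor) -> torch.Tensor:
+     # Inside an opaque custom operator the Triton kernel is launched directly;
+     # ``wrap_triton`` is not needed because ``torch.compile`` does not trace in here.
+     out = torch.empty_like(x)
+     n_elements = x.numel()
+     sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+     return out
+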
+ ######################################################################
+ # You can invoke the ``triton_op`` in one of the following two ways.
+
+ x = torch.randn(3, device="cuda")
+ y = mysin(x)
+ z = torch.ops.mylib.mysin.default(x)
+
+ assert torch.allclose(y, x.sin())
+ assert torch.allclose(z, x.sin())
+
+ ######################################################################
+ # The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.
+
+ y = torch.compile(mysin)(x)
+ assert torch.allclose(y, x.sin())
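+
+ ######################################################################
+ # As a rough sketch (not part of the original example), the same ``triton_op`` can
+ # also be exported and compiled ahead-of-time with AOTInductor. The module name and
+ # the packaging helpers below are illustrative and assume PyTorch 2.6+:
+ #
+ # >>> class MySinModule(torch.nn.Module):
+ # >>>     def forward(self, x):
+ # >>>         return mysin(x)
+ # >>>
+ # >>> ep = torch.export.export(MySinModule(), (x,))
+ # >>> path = torch._inductor.aoti_compile_and_package(ep)
+ # >>> compiled = torch._inductor.aoti_load_package(path)
+ # >>> assert torch.allclose(compiled(x), x.sin())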
+
+ ######################################################################
+ # Adding training support
+ # ^^^^^^^^^^^^^^^^^^^^^^^
+ #
+ # Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
+ # Prefer this to using ``torch.autograd.Function`` (which has various composability footguns
+ # with ``torch.compile``).
+
+ def backward(ctx, grad_output):
+     x, = ctx.saved_tensors
+     return grad_output * x.cos()
+
+ def setup_context(ctx, inputs, output):
+     x, = inputs
+     ctx.save_for_backward(x)
+
+ mysin.register_autograd(backward, setup_context=setup_context)
+
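+ ######################################################################
+ # As a quick check (an illustrative addition, not in the original example), the
+ # registered backward should now produce the analytical gradient,
+ # d/dx sin(x) = cos(x):
+
+ x = torch.randn(3, device="cuda", requires_grad=True)
+ mysin(x).sum().backward()
+ assert torch.allclose(x.grad, x.cos())
+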
+ ######################################################################
+ # Note that the backward must be a composition of PyTorch-understood operators.
+ # If you want the backward to call Triton kernels, then those must also be wrapped in ``triton_op``:
+
+ @triton.jit
+ def cos_kernel(
+     in_ptr0,
+     out_ptr,
+     n_elements,
+     BLOCK_SIZE: "tl.constexpr",
+ ):
+     pid = tl.program_id(axis=0)
+     block_start = pid * BLOCK_SIZE
+     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_elements
+     x = tl.load(in_ptr0 + offsets, mask=mask)
+     output = tl.cos(x)
+     tl.store(out_ptr + offsets, output, mask=mask)
+
+ @triton_op("mylib::mycos", mutates_args={})
+ def mycos(x: torch.Tensor) -> torch.Tensor:
+     out = torch.empty_like(x)
+     n_elements = x.numel()
+     wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+     return out
+
+ def backward(ctx, grad_output):
+     x, = ctx.saved_tensors
+     return grad_output * mycos(x)
+
+ def setup_context(ctx, inputs, output):
+     x, = inputs
+     ctx.save_for_backward(x)
+
+ mysin.register_autograd(backward, setup_context=setup_context)
+
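+ ######################################################################
+ # An illustrative end-to-end check (not part of the original example): with the
+ # Triton-backed backward registered above, the operator should still train under
+ # ``torch.compile``.
+
+ x = torch.randn(3, device="cuda", requires_grad=True)
+ y = torch.compile(mysin)(x)
+ y.sum().backward()
+ assert torch.allclose(x.grad, x.cos())
+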
+ ######################################################################
+ # Adding a CPU Fallback
+ # ^^^^^^^^^^^^^^^^^^^^^
+ # Triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any other device) fallback for the ``triton_op``:
+
+ @mysin.register_kernel("cpu")
+ def _(x):
+     return torch.sin(x)
+
+ x = torch.randn(3)
+ y = mysin(x)
+ assert torch.allclose(y, x.sin())
+
+ ######################################################################
+ # The fallback must be composed of PyTorch operators.
+
+ ######################################################################
+ # Adding a FlopCounter Formula
+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ #
+ # To specify how many flops the Triton kernel reports under PyTorch's flop counter,
+ # use ``register_flop_formula``.
+
+ from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
+
+ @register_flop_formula(torch.ops.mylib.mysin)
+ def _(x_shape):
+     numel = 1
+     for s in x_shape:
+         numel *= s
+     return numel
+
+ x = torch.randn(3, device="cuda")
+
+ #########################################################
+ # ``FlopCounterMode`` requires `tabulate <https://pypi.org/project/tabulate/>`__.
+ # Before running the code below, make sure you have ``tabulate`` installed, or install it
+ # by running ``pip install tabulate``.
+ #
+ # >>> with FlopCounterMode() as flop_counter:
+ # >>>     y = mysin(x)
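+ #
+ # As an illustrative follow-up (not in the original example), the total recorded by
+ # the formula registered above can then be read back:
+ #
+ # >>> flop_counter.get_total_flops()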
+
+ ######################################################################
+ # Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
#
+ # PyTorch 2.6 introduced ``torch.library.triton_op``, which extends user-defined
+ # Triton kernel support to tensor subclasses and other advanced features.
+ #
# However, there are certain limitations to be aware of:
#
- # * **Tensor Subclasses:** Currently, there is no support for
- # tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
# before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
# implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used