@@ -140,17 +140,220 @@ def add_fn(x, y):
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
######################################################################
- # Composibility and Limitations
+ # Composability
+ # -------------------------------------------------------------------
+ #
+ # User-defined triton kernels do not automatically work with all PyTorch
+ # subsystems. For example, out of the box they do not support:
+ #
+ # - Adding a CPU fallback
+ # - Adding a ``FlopCounter`` formula
+ # - Composing with Tensor Subclasses
+ #
+ # To compose with additional PyTorch subsystems, use ``torch.library.triton_op``.
+ #
+ # ``triton_op`` is a structured way of defining a custom operator that is backed by
+ # one or more triton kernels: like regular custom operators (``torch.library.custom_op``),
+ # you can specify its interactions with PyTorch subsystems via ``torch.library``.
+ # However, unlike ``torch.library.custom_op``, which creates callables that are opaque
+ # to ``torch.compile``, ``torch.compile`` traces into a ``triton_op`` to apply optimizations.
+ #
+ # Here’s a table of which API to use when integrating triton kernels with PyTorch.
+ #
+ # .. list-table::
+ #    :header-rows: 1
+ #
+ #    * -
+ #      - triton kernel (no explicit torch.library wrapper)
+ #      - torch.library.triton_op
+ #      - torch.library.custom_op
+ #    * - Supports inference
+ #      - Yes
+ #      - Yes
+ #      - Yes
+ #    * - Supports training
+ #      - In the majority of cases
+ #      - Yes
+ #      - Yes
+ #    * - Supports torch.compile
+ #      - Yes
+ #      - Yes
+ #      - Yes
+ #    * - Supports torch.compile(fullgraph=True)
+ #      - In the majority of cases
+ #      - In the majority of cases
+ #      - In all cases
+ #    * - Does torch.compile trace into the implementation?
+ #      - Yes
+ #      - Yes
+ #      - No
+ #    * - Supports AOTInductor
+ #      - Yes
+ #      - Yes
+ #      - No
+ #    * - Supports PyTorch Subsystems like FlopCounterMode, CPU Fallback, Tensor Subclasses
+ #      - No
+ #      - Yes
+ #      - Yes
+
+ ######################################################################
+ # Wrapping triton kernels with triton_op
+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ #
+ # Use ``torch.library.triton_op`` to wrap a function that may invoke one or more triton kernels.
+ # Use ``torch.library.wrap_triton`` to wrap the calls to the triton kernel.
+
+ from torch.library import triton_op, wrap_triton
+
+ @triton_op("mylib::mysin", mutates_args={})
+ def mysin(x: torch.Tensor) -> torch.Tensor:
+     out = torch.empty_like(x)
+     n_elements = x.numel()
+     wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+     return out
+
+ @triton.jit
+ def sin_kernel(
+     in_ptr0,
+     out_ptr,
+     n_elements,
+     BLOCK_SIZE: "tl.constexpr",
+ ):
+     pid = tl.program_id(axis=0)
+     block_start = pid * BLOCK_SIZE
+     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_elements
+     x = tl.load(in_ptr0 + offsets, mask=mask)
+     output = tl.sin(x)
+     tl.store(out_ptr + offsets, output, mask=mask)
+
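+ # For comparison, ``sin_triton`` below launches the same kernel directly, without
+ # ``wrap_triton``: ``torch.compile`` still traces into it, but it does not get the
+ # ``torch.library`` integration described above.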
+ def sin_triton(x):
+     out = torch.empty_like(x)
+     n_elements = x.numel()
+     sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+     return out
+
+ ######################################################################
+ # You can invoke the ``triton_op`` in one of the following two ways.
+
+ x = torch.randn(3, device="cuda")
+ y = mysin(x)
+ z = torch.ops.mylib.mysin.default(x)
+
+ assert torch.allclose(y, x.sin())
+ assert torch.allclose(z, x.sin())
+
+ ######################################################################
+ # The resulting ``triton_op`` works with ``torch.compile`` and ``AOTInductor``.
+
+ y = torch.compile(mysin)(x)
+ assert torch.allclose(y, x.sin())
+
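+ ######################################################################
+ # ``AOTInductor`` usage is only sketched here, not run. This is a minimal, hedged
+ # sketch assuming PyTorch 2.6+; ``MySinModule`` is a hypothetical wrapper defined
+ # only because ``torch.export.export`` expects an ``nn.Module``:
+ #
+ # >>> class MySinModule(torch.nn.Module):
+ # >>>     def forward(self, x):
+ # >>>         return mysin(x)
+ # >>>
+ # >>> ep = torch.export.export(MySinModule(), (x,))
+ # >>> package_path = torch._inductor.aoti_compile_and_package(ep)
+ # >>> compiled = torch._inductor.aoti_load_package(package_path)
+ # >>> assert torch.allclose(compiled(x), x.sin())
+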
+ ######################################################################
+ # Adding training support
+ # ^^^^^^^^^^^^^^^^^^^^^^^
+ #
+ # Use ``register_autograd`` to add an autograd formula for the ``triton_op``.
+ # Prefer this to using ``torch.autograd.Function``, which has various composability
+ # footguns when used with ``torch.compile``.
+
+ def backward(ctx, grad_output):
+     x, = ctx.saved_tensors
+     return grad_output * x.cos()
+
+ def setup_context(ctx, inputs, output):
+     x, = inputs
+     ctx.save_for_backward(x)
+
+ mysin.register_autograd(backward, setup_context=setup_context)
+
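+ ######################################################################
+ # As a quick check of the registered formula (the gradient of ``sin`` is ``cos``),
+ # backpropagate through ``mysin`` and compare against ``x.cos()``:
+
+ x = torch.randn(3, device="cuda", requires_grad=True)
+ y = mysin(x)
+ y.sum().backward()
+ assert torch.allclose(x.grad, x.cos())
+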
+ ######################################################################
+ # Note that the backward must be a composition of PyTorch-understood operators.
+ # If you want the backward to call triton kernels, those kernels must be wrapped
+ # in ``triton_op`` as well:
+
+ @triton.jit
+ def cos_kernel(
+     in_ptr0,
+     out_ptr,
+     n_elements,
+     BLOCK_SIZE: "tl.constexpr",
+ ):
+     pid = tl.program_id(axis=0)
+     block_start = pid * BLOCK_SIZE
+     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_elements
+     x = tl.load(in_ptr0 + offsets, mask=mask)
+     output = tl.cos(x)
+     tl.store(out_ptr + offsets, output, mask=mask)
+
+ @triton_op("mylib::mycos", mutates_args={})
+ def mycos(x: torch.Tensor) -> torch.Tensor:
+     out = torch.empty_like(x)
+     n_elements = x.numel()
+     wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
+     return out
+
+ def backward(ctx, grad_output):
+     x, = ctx.saved_tensors
+     return grad_output * mycos(x)
+
+ def setup_context(ctx, inputs, output):
+     x, = inputs
+     ctx.save_for_backward(x)
+
+ mysin.register_autograd(backward, setup_context=setup_context)
+
+ ######################################################################
+ # Adding a CPU Fallback
+ # ^^^^^^^^^^^^^^^^^^^^^
+ #
+ # Triton kernels don’t run on CPU. Use ``register_kernel`` to add a CPU (or any
+ # other device) fallback for the ``triton_op``:
+
+ @mysin.register_kernel("cpu")
+ def _(x):
+     return torch.sin(x)
+
+ x = torch.randn(3)
+ y = mysin(x)
+ assert torch.allclose(y, x.sin())
+
+ ######################################################################
+ # The fallback must be composed of PyTorch operators.
+
+ ######################################################################
+ # Adding a FlopCounter Formula
+ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ #
+ # To specify how many flops the triton kernel reports under PyTorch's flop counter,
+ # use ``register_flop_formula``.
+
+ from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
+
+ @register_flop_formula(torch.ops.mylib.mysin)
+ def _(x_shape):
+     numel = 1
+     for s in x_shape:
+         numel *= s
+     return numel
+
+ x = torch.randn(3, device="cuda")
+
+ # NB: FlopCounterMode requires tabulate.
+ #
+ # >>> with FlopCounterMode() as flop_counter:
+ # >>>     y = mysin(x)
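+ # >>>
+ # >>> # With the formula above, the 3-element input reports 3 FLOPs (one per
+ # >>> # element); a hedged check, assuming nothing else runs under the mode:
+ # >>> assert flop_counter.get_total_flops() == 3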
+
+ ######################################################################
+ # Limitations
# --------------------------------------------------------------------
#
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
# You can use these features together to build complex, high-performance models.
#
+ # PyTorch 2.6 added ``torch.library.triton_op``, which lets user-defined Triton
+ # kernels compose with tensor subclasses and other advanced features.
+ #
# However, there are certain limitations to be aware of:
#
- # * **Tensor Subclasses:** Currently, there is no support for
- #   tensor subclasses and other advanced features.
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
#   implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used