This tutorial is for PyTorch 2.4+ and the PyTorch nightlies.

PyTorch offers a large library of operators that work on Tensors (e.g.
``torch.add``, ``torch.sum``, etc). However, you might wish to use a new customized
operator with PyTorch, perhaps written by a third-party library. This tutorial
shows how to wrap Python functions so that they behave like PyTorch native
operators. Reasons why you may wish to create a custom operator in PyTorch include:

- Treating an arbitrary Python function as an opaque callable with respect
  to ``torch.compile`` (that is, prevent ``torch.compile`` from tracing
  into the function).
- Adding training support to an arbitrary Python function

Please note that if your operation can be expressed as a composition of
existing PyTorch operators, then there is usually no need to use the custom operator
API -- everything (for example ``torch.compile``, training support) should
just work.
"""

######################################################################
# Example: Wrapping PIL's crop into a custom operator
# ----------------------------------------------------
# Let's say that we are using PIL's ``crop`` operation.

import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt

def crop(pic, box):
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return pil_to_tensor(cropped_img).to(pic.device) / 255.

def display(img):
    plt.imshow(img.numpy().transpose((1, 2, 0)))

img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)

######################################################################

cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)

######################################################################
# ``crop`` is not handled effectively out-of-the-box by
# ``torch.compile``: ``torch.compile`` induces a
# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
# on functions it is unable to handle, and graph breaks are bad for performance.
# The following code demonstrates this by raising an error
# (``torch.compile`` with ``fullgraph=True`` raises an error if a
# graph break occurs).

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

# The following raises an error. Uncomment the line to see it.
# cropped_img = f(img)
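
######################################################################
# An optional aside, not part of the original tutorial: you can observe the
# graph break without raising an error by using ``torch._dynamo.explain``
# (a private utility that may change between releases) instead of
# ``fullgraph=True``. A sketch under that assumption:

explanation = torch._dynamo.explain(crop)(img, (10, 10, 50, 50))
print(explanation.graph_break_count)  # expected to be >= 1 for this ``crop``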

######################################################################
# In order to black-box ``crop`` for use with ``torch.compile``, we need to
# do two things:
#
# 1. wrap the function into a PyTorch custom operator.
# 2. add a "FakeTensor kernel" (aka "meta kernel") to the operator.
#    Given the metadata (e.g. shapes)
#    of the input Tensors, this function says how to compute the metadata
#    of the output Tensor(s).


from typing import Sequence

# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

# Use register_fake to add a FakeTensor kernel for the operator
@crop.register_fake
def _(pic, box):
    channels = pic.shape[0]
    x0, y0, x1, y1 = box
    return pic.new_empty(channels, y1 - y0, x1 - x0)
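
######################################################################
# An optional sanity check, not part of the original tutorial (it relies on
# the private ``FakeTensorMode`` API, which may change between releases):
# calling the operator on fake tensors exercises the fake kernel, letting us
# confirm the inferred output shape without doing any real work.

from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    fake_pic = torch.empty(3, 64, 64)
    fake_out = crop(fake_pic, (10, 10, 50, 50))
print(fake_out.shape)  # expected: torch.Size([3, 40, 40])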

######################################################################
# After this, ``crop`` now works without graph breaks:

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

cropped_img = f(img)
display(img)

######################################################################

display(cropped_img)

######################################################################
# Adding training support for crop
# --------------------------------
# Use ``torch.library.register_autograd`` to add training support for an operator.
# Prefer this over directly using ``torch.autograd.Function``; some compositions of
# ``autograd.Function`` with PyTorch operator registration APIs can lead to (and
# has led to) silent incorrectness when composed with ``torch.compile``.
#
# The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the
# derivation as an exercise to the reader). Let's first wrap ``paste`` into a
# custom operator:

@torch.library.custom_op("mylib::paste", mutates_args=())
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    im1_pil = to_pil_image(im1.cpu())
    im2_pil = to_pil_image(im2.cpu())
    PIL.Image.Image.paste(im1_pil, im2_pil, coord)
    return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)

@paste.register_fake
def _(im1, im2, coord):
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    return torch.empty_like(im1)

######################################################################
# And now let's use ``register_autograd`` to specify the gradient formula for ``crop``:

def backward(ctx, grad_output):
    grad_input = grad_output.new_zeros(ctx.pic_shape)
    grad_input = paste(grad_input, grad_output, ctx.coords)
    return grad_input, None

def setup_context(ctx, inputs, output):
    pic, box = inputs
    ctx.coords = box[:2]
    ctx.pic_shape = pic.shape

crop.register_autograd(backward, setup_context=setup_context)

######################################################################
# Note that the backward must be a composition of PyTorch-understood operators,
# which is why we wrapped paste into a custom operator instead of directly using
# PIL's paste.

img = img.requires_grad_()
result = crop(img, (10, 10, 50, 50))
result.sum().backward()
display(img.grad)

######################################################################
# This is the correct gradient: 1 (white) in the cropped region and 0
# (black) in the unused region.
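
######################################################################
# An optional check, not part of the original tutorial. It assumes the
# (left, upper, right, lower) crop box used above, so the gradient should
# be exactly 1 inside the box and 0 elsewhere.

assert torch.allclose(img.grad[:, 10:50, 10:50], torch.ones(3, 40, 40))
assert img.grad.sum() == 3 * 40 * 40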

######################################################################
# Testing Python Custom operators
# --------------------------------
# Use ``torch.library.opcheck`` to test that the custom operator was registered
# correctly. This does not test that the gradients are mathematically correct;
# please write separate tests for that (either manual ones or ``torch.autograd.gradcheck``).
#
# To use ``opcheck``, pass it a set of example inputs to test against. If your
# operator supports training, then the examples should include Tensors that
# require grad. If your operator supports multiple devices, then the examples
# should include Tensors from each device.

examples = [
    # Representative inputs, following the guidance above: with and without
    # requires_grad, and with more than one dtype.
    [torch.randn(3, 64, 64), [0, 0, 10, 10]],
    [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
    [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
    [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
]

for example in examples:
    torch.library.opcheck(crop, example)

######################################################################
# Mutable Python Custom operators
# --------------------------------
# You can also wrap a Python function that mutates its inputs into a custom
# operator.
# Functions that mutate inputs are common because that is how many low-level
# kernels are written; for example, a kernel that computes ``sin`` may take in
# the input and an output tensor and write ``input.sin()`` to the output tensor.
#
# We'll use ``numpy.sin`` to demonstrate an example of a mutable Python
# custom operator.

import numpy as np

@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"})
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.device == output.device
    assert input.device.type == "cpu"
    input_np = input.numpy()
    output_np = output.numpy()
    np.sin(input_np, out=output_np)

######################################################################
# Because the operator doesn't return anything, there is no need to register
# a FakeTensor kernel (meta kernel) to get it to work with ``torch.compile``.

@torch.compile(fullgraph=True)
def f(x):
    out = torch.empty(3)
    numpy_sin(x, out)
    return out

x = torch.randn(3)
y = f(x)
assert torch.allclose(y, x.sin())
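
######################################################################
# An optional aside, not part of the original tutorial: a common convention
# is to expose a functional wrapper (the name ``numpy_sin_functional`` below
# is hypothetical) around a mutating kernel, so callers do not have to
# preallocate the output themselves.

def numpy_sin_functional(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    numpy_sin(x, out)
    return out

assert torch.allclose(numpy_sin_functional(torch.ones(3)), torch.ones(3).sin())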

######################################################################
# And here's an ``opcheck`` run telling us that we did indeed register the operator correctly.
# ``opcheck`` would error out if we forgot to add the output to ``mutates_args``, for example.

example_inputs = [
    [torch.randn(3), torch.empty(3)],
]

for example in example_inputs:
    torch.library.opcheck(numpy_sin, example)

######################################################################
# Conclusion
# ----------
# In this tutorial, we learned how to use ``torch.library.custom_op`` to
# create a custom operator in Python that works with PyTorch subsystems
# such as ``torch.compile`` and autograd.
#
# This tutorial provides a basic introduction to custom operators.
# For more detailed information, see:
#
# - `the torch.library documentation <https://pytorch.org/docs/stable/library.html>`_
# - `the Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
#