
Commit 4703d17

committed
update
1 parent edbcbcb commit 4703d17

File tree

2 files changed, +87 -64 lines changed


advanced_source/cpp_custom_ops.rst

Lines changed: 22 additions & 22 deletions
@@ -47,7 +47,7 @@ Using ``cpp_extension`` is as simple as writing the following ``setup.py``:
       cmdclass={'build_ext': cpp_extension.BuildExtension})
 
 If you need to compile CUDA code (for example, ``.cu`` files), then instead use
-`torch.utils.cpp_extension.CUDAExtension <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension>`_
+`torch.utils.cpp_extension.CUDAExtension <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension>`_.
 Please see how
 `extension-cpp <https://github.com/pytorch/extension-cpp>`_ for an example for
 how this is set up.
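For orientation, a minimal, hedged sketch of the ``CUDAExtension`` variant of such a ``setup.py``; the package name and source file names below are placeholders, not the ones used by extension-cpp.

    # Hypothetical setup.py for a mixed C++/CUDA extension.
    from setuptools import setup
    from torch.utils import cpp_extension

    setup(
        name="extension_cpp",  # placeholder package name
        ext_modules=[
            cpp_extension.CUDAExtension(
                name="extension_cpp",
                sources=["muladd.cpp", "muladd_kernel.cu"],  # hypothetical source files
            )
        ],
        cmdclass={"build_ext": cpp_extension.BuildExtension},
    )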
@@ -167,14 +167,14 @@ for more details).
 
 .. code-block:: python
 
-  @torch.library.register_fake("extension_cpp::mymuladd")
-  def _(a, b, c):
-      torch._check(a.shape == b.shape)
-      torch._check(a.dtype == torch.float)
-      torch._check(b.dtype == torch.float)
-      torch._check(a.device == b.device)
-      return torch.empty_like(a)
-
+  @torch.library.register_fake("extension_cpp::mymuladd")
+  def _(a, b, c):
+      torch._check(a.shape == b.shape)
+      torch._check(a.dtype == torch.float)
+      torch._check(b.dtype == torch.float)
+      torch._check(a.device == b.device)
+      return torch.empty_like(a)
+
 
 How to add training (autograd) support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Use ``torch.library.register_autograd`` to add training support for an operator. Prefer
@@ -353,19 +353,19 @@ When defining the operator, we must specify that it mutates the out Tensor in th
 
 .. code-block:: cpp
 
-  TORCH_LIBRARY(extension_cpp, m) {
-    m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
-    m.def("mymul(Tensor a, Tensor b) -> Tensor");
-    // New!
-    m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
-  }
-
-  TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
-    m.impl("mymuladd", &mymuladd_cpu);
-    m.impl("mymul", &mymul_cpu);
-    // New!
-    m.impl("myadd_out", &myadd_out_cpu);
-  }
+  TORCH_LIBRARY(extension_cpp, m) {
+    m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
+    m.def("mymul(Tensor a, Tensor b) -> Tensor");
+    // New!
+    m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
+  }
+
+  TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+    m.impl("mymuladd", &mymuladd_cpu);
+    m.impl("mymul", &mymul_cpu);
+    // New!
+    m.impl("myadd_out", &myadd_out_cpu);
+  }
 
 .. note::
 
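As a hedged usage sketch of the out-variant declared above: the ``Tensor(a!)`` annotation in the schema marks ``out`` as mutated, and the operator returns nothing. This assumes the compiled ``extension_cpp`` extension has been built and imported so the operator resolves under ``torch.ops``.

    import torch
    import extension_cpp  # hypothetical import of the built extension

    a, b = torch.randn(3), torch.randn(3)
    out = torch.empty(3)
    torch.ops.extension_cpp.myadd_out(a, b, out)  # writes a + b into `out`, returns None
    assert torch.allclose(out, a + b)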
advanced_source/python_custom_ops.py

Lines changed: 65 additions & 42 deletions
@@ -10,67 +10,79 @@
 This tutorial is for PyTorch 2.4+ and the PyTorch nightlies.
 
 PyTorch offers a large library of operators that work on Tensors (e.g.
-torch.add, torch.sum, etc). However, you may wish to use a new customized
+``torch.add``, ``torch.sum``, etc). However, you might wish to use a new customized
 operator with PyTorch, perhaps written by a third-party library. This tutorial
 shows how to wrap Python functions so that they behave like PyTorch native
-operators. Reasons why you may wish to create a custom op in PyTorch include:
+operators. Reasons why you may wish to create a custom operator in PyTorch include:
 
-- Black-box-ing an arbitrary Python function for use with torch.compile
+- Treating an arbitrary Python function as an opaque callable with respect
+  to ``torch.compile`` (that is, prevent ``torch.compile`` from tracing
+  into the function).
 - Adding training support to an arbitrary Python function
 
 Please note that if your operation can be expressed as a composition of
-existing PyTorch ops, then there is usually no need to use the custom op
-API -- everything (e.g. torch.compile, training support) should just work.
+existing PyTorch operators, then there is usually no need to use the custom operator
+API -- everything (for example ``torch.compile``, training support) should
+just work.
 """
 ######################################################################
-# Wrapping PIL's crop into a custom op
+# Example: Wrapping PIL's crop into a custom operator
 # ------------------------------------
-# Let's say that we are using PIL's crop operation.
+# Let's say that we are using PIL's ``crop`` operation.
 
 import torch
 from torchvision.transforms.functional import to_pil_image, pil_to_tensor
 import PIL
 import IPython
+import matplotlib.pyplot as plt
 
 def crop(pic, box):
     img = to_pil_image(pic.cpu())
     cropped_img = img.crop(box)
     return pil_to_tensor(cropped_img).to(pic.device) / 255.
 
 def display(img):
-    img_pil = to_pil_image(img)
-    IPython.display.display(img_pil)
-
+    plt.imshow(img.numpy().transpose((1, 2, 0)))
 
 img = torch.ones(3, 64, 64)
 img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
 display(img)
+
+######################################################################
+
 cropped_img = crop(img, (10, 10, 50, 50))
 display(cropped_img)
 
 ######################################################################
-# ``crop`` doesn't work performantly out-of-the-box with torch.compile. The
-# following code leads to an error when run.
+# ``crop`` is not handled effectively out-of-the-box by
+# ``torch.compile``: ``torch.compile`` induces a
+# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
+# on functions it is unable to handle and graph breaks are bad for performance.
+# The following code demonstrates this by raising an error
+# (``torch.compile`` with ``fullgraph=True`` raises an error if a
+# graph break occurs).
 
-"""
 @torch.compile(fullgraph=True)
 def f(img):
     return crop(img, (10, 10, 50, 50))
 
-cropped_img = f(img)
-"""
+# The following raises an error. Uncomment the line to see it.
+# cropped_img = f(img)
 
 ######################################################################
-# In order to black-box ``crop`` for use with ``torch.compile``, we need to do two things:
+# In order to black-box ``crop`` for use with ``torch.compile``, we need to
+# do two things:
 #
-# - wrap the function into a PyTorch custom op.
-# - add a "FakeTensor kernel" (aka "meta kernel") to the op. Given the metadata (e.g. shapes)
-#   of the input Tensors, this function says how to compute the metadata of the output Tensor(s).
+# 1. wrap the function into a PyTorch custom operator.
+# 2. add a "FakeTensor kernel" (aka "meta kernel") to the operator.
+#    Given the metadata (e.g. shapes)
+#    of the input Tensors, this function says how to compute the metadata
+#    of the output Tensor(s).
 
 
 from typing import Sequence
 
-# Use torch.library.custom_op to define a new custom op.
+# Use torch.library.custom_op to define a new custom operator.
 # If your operator mutates any input Tensors, their names must be specified
 # in the mutates_args argument.
 @torch.library.custom_op("mylib::crop", mutates_args=())
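As a side note on the graph-break behavior discussed in this hunk, here is a hedged sketch of what the commented-out call surfaces: run the plain Python ``crop`` (before it is turned into a custom operator) under ``torch.compile(fullgraph=True)`` and catch the error. The exact exception type and message depend on the PyTorch version.

    # Illustrative only: fullgraph=True turns the graph break caused by the
    # plain-Python `crop` into a hard error instead of a silent fallback.
    @torch.compile(fullgraph=True)
    def f_plain(img):
        return crop(img, (10, 10, 50, 50))

    try:
        f_plain(torch.ones(3, 64, 64))
    except Exception as err:
        print("torch.compile raised:", type(err).__name__)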
@@ -79,22 +91,25 @@ def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
     cropped_img = img.crop(box)
     return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)
 
-# Use register_fake to add a FakeTensor kernel for the op
+# Use register_fake to add a FakeTensor kernel for the operator
 @crop.register_fake
 def _(pic, box):
     channels = pic.shape[0]
     x0, y0, x1, y1 = box
     return pic.new_empty(channels, y1 - y0, x1 - x0)
 
 ######################################################################
-# After this, crop now works with torch.compile:
+# After this, ``crop`` now works without graph breaks:
 
 @torch.compile(fullgraph=True)
 def f(img):
     return crop(img, (10, 10, 50, 50))
 
 cropped_img = f(img)
 display(img)
+
+######################################################################
+
 display(cropped_img)
 
 ######################################################################
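A small, hedged sanity check of the fake kernel shown above: for the box ``(10, 10, 50, 50)`` (left, upper, right, lower), both the fake kernel and the real ``crop`` should produce a 40x40 output, so the compiled call can be checked like this.

    # Hedged check: pic.new_empty(channels, y1 - y0, x1 - x0) gives (3, 40, 40) here.
    out = f(torch.ones(3, 64, 64))
    assert out.shape == (3, 40, 40)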
@@ -103,11 +118,11 @@ def f(img):
 # Use ``torch.library.register_autograd`` to add training support for an operator.
 # Prefer this over directly using ``torch.autograd.Function``; some compositions of
 # ``autograd.Function`` with PyTorch operator registration APIs can lead to (and
-# has led to) silent incorrectness.
+# has led to) silent incorrectness when composed with ``torch.compile``.
 #
 # The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the
 # derivation as an exercise to the reader). Let's first wrap ``paste`` into a
-# custom op:
+# custom operator:
 
 @torch.library.custom_op("mylib::paste", mutates_args=())
 def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
@@ -125,7 +140,7 @@ def _(im1, im2, coord):
     return torch.empty_like(im1)
 
 ######################################################################
-# And now let's use register_autograd to specify the gradient formula for ``crop``:
+# And now let's use ``register_autograd`` to specify the gradient formula for ``crop``:
 
 def backward(ctx, grad_output):
     grad_input = grad_output.new_zeros(ctx.pic_shape)
@@ -141,7 +156,7 @@ def setup_context(ctx, inputs, output):
 
 ######################################################################
 # Note that the backward must be a composition of PyTorch-understood operators,
-# which is why we wrapped paste into a custom op instead of directly using
+# which is why we wrapped paste into a custom operator instead of directly using
 # PIL's paste.
 
 img = img.requires_grad_()
154169
# (black) in the unused region.
155170

156171
######################################################################
157-
# Testing Python Custom Ops
172+
# Testing Python Custom operators
158173
# -------------------------
159-
# Use torch.library.opcheck to test that the custom op was registered
174+
# Use ``torch.library.opcheck`` to test that the custom operator was registered
160175
# correctly. This does not test that the gradients are mathematically correct;
161-
# please write separate tests for that (either manual ones or torch.autograd.gradcheck).
176+
# please write separate tests for that (either manual ones or ``torch.autograd.gradcheck``).
162177
#
163-
# To use opcheck, pass it a set of example inputs to test against. If your
178+
# To use ``opcheck``, pass it a set of example inputs to test against. If your
164179
# operator supports training, then the examples should include Tensors that
165-
# require grad. If your operator supports multiple devices, then the examplesxi
180+
# require grad. If your operator supports multiple devices, then the examples
166181
# should include Tensors from each device.
167182

168183
examples = [
@@ -176,14 +191,16 @@ def setup_context(ctx, inputs, output):
176191
torch.library.opcheck(crop, example)
177192

178193
######################################################################
179-
# Mutable Python Custom Ops
194+
# Mutable Python Custom operators
180195
# -------------------------
181-
# You can also wrap a Python function that mutates its inputs into a custom op.
196+
# You can also wrap a Python function that mutates its inputs into a custom
197+
# operator.
182198
# Functions that mutate inputs are common because that is how many low-level
183-
# kernels are written; for example, a kernel that computes sin may take in the
184-
# input and an output tensor and write ``input.sin()`` to the output tensor.
199+
# kernels are written; for example, a kernel that computes ``sin`` may take in
200+
# the input and an output tensor and write ``input.sin()`` to the output tensor.
185201
#
186-
# We'll use numpy.sin to demonstrate an example of a mutable Python custom op.
202+
# We'll use ``numpy.sin`` to demonstrate an example of a mutable Python
203+
# custom operator.
187204

188205
import numpy as np
189206

@@ -196,9 +213,8 @@ def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
     np.sin(input_np, out=output_np)
 
 ######################################################################
-# This custom op automatically works with torch.compile.
-# Because the op doesn't return anything, there is no need to register
-# a FakeTensor kernel (meta kernel).
+# Because the operator doesn't return anything, there is no need to register
+# a FakeTensor kernel (meta kernel) to get it to work with ``torch.compile``.
 
 @torch.compile(fullgraph=True)
 def f(x):
@@ -211,8 +227,8 @@ def f(x):
 assert torch.allclose(y, x.sin())
 
 ######################################################################
-# And here's an opcheck run telling us that we did indeed register the op correctly.
-# opcheck would error out if we forgot to add the output to ``mutates_args``, for example.
+# And here's an ``opcheck`` run telling us that we did indeed register the operator correctly.
+# ``opcheck`` would error out if we forgot to add the output to ``mutates_args``, for example.
 
 example_inputs = [
     [torch.randn(3), torch.empty(3)],
@@ -226,6 +242,13 @@ def f(x):
 ######################################################################
 # Conclusion
 # ----------
-# For more information, please see:
+# In this tutorial, we learned how to use ``torch.library.custom_op`` to
+# create a custom operator in Python that works with PyTorch subsystems
+# such as ``torch.compile`` and autograd.
+#
+# This tutorial provides a basic introduction to custom operators.
+# For more detailed information, see:
+#
 # - `the torch.library documentation <https://pytorch.org/docs/stable/library.html>`_
 # - `the Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
+#

0 commit comments
