Commit 397012b

Update
1 parent 551d235 commit 397012b

3 files changed: +32, -5 lines

advanced_source/cpp_custom_ops.rst

Lines changed: 25 additions & 0 deletions
@@ -167,13 +167,38 @@ for more details).
 
 .. code-block:: python
 
+    # Important: the C++ custom operator definitions should be loaded first
+    # before calling ``torch.library`` APIs that add registrations for the
+    # C++ custom operator(s). The following import loads our
+    # C++ custom operator definitions.
+    # See the next section for more details.
+    from . import _C
+
     @torch.library.register_fake("extension_cpp::mymuladd")
     def _(a, b, c):
         torch._check(a.shape == b.shape)
         torch._check(a.dtype == torch.float)
         torch._check(b.dtype == torch.float)
         torch._check(a.device == b.device)
         return torch.empty_like(a)
+
+How to set up hybrid Python/C++ registration
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+In this tutorial, we defined a custom operator in C++, added CPU/CUDA
+implementations in C++, and added ``FakeTensor`` kernels and backward formulas
+in Python. The order in which these registrations are loaded (or imported)
+matters (importing in the wrong order will lead to an error).
+
+To use the custom operator with hybrid Python/C++ registrations, we must
+first load the C++ library that holds the custom operator definition
+and then call the ``torch.library`` registration APIs. This can happen in one
+of two ways:
+
+1. If you're following this tutorial, importing the Python C extension module
+   we created will load the C++ custom operator definitions.
+2. If your C++ custom operator is located in a shared library object, you can
+   also use ``torch.ops.load_library("/path/to/library.so")`` to load it.
+
 
 How to add training (autograd) support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
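
Putting the two pieces of this hunk together, a registration module in the tutorial's extension package might look like the following sketch. It only restates what the added lines describe; the commented-out ``torch.ops.load_library`` call uses the placeholder path from the added text.

    import torch

    # 1) Load the C++ custom operator definitions first. In the tutorial's
    #    package this is the Python C extension module:
    from . import _C  # noqa: F401
    # ...or, if the operators live in a standalone shared library instead:
    # torch.ops.load_library("/path/to/library.so")

    # 2) Only now add Python-side registrations such as the FakeTensor kernel.
    @torch.library.register_fake("extension_cpp::mymuladd")
    def _(a, b, c):
        torch._check(a.shape == b.shape)
        torch._check(a.dtype == torch.float)
        torch._check(b.dtype == torch.float)
        torch._check(a.device == b.device)
        return torch.empty_like(a)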

advanced_source/python_custom_ops.py

Lines changed: 5 additions & 5 deletions
@@ -74,7 +74,7 @@ def f(img):
 # do two things:
 #
 # 1. wrap the function into a PyTorch custom operator.
-# 2. add a "FakeTensor kernel" (aka "meta kernel") to the operator.
+# 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator.
 # Given the metadata (e.g. shapes)
 # of the input Tensors, this function says how to compute the metadata
 # of the output Tensor(s).
@@ -84,22 +84,22 @@ def f(img):
 
 # Use torch.library.custom_op to define a new custom operator.
 # If your operator mutates any input Tensors, their names must be specified
-# in the mutates_args argument.
+# in the ``mutates_args`` argument.
 @torch.library.custom_op("mylib::crop", mutates_args=())
 def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
     img = to_pil_image(pic.cpu())
     cropped_img = img.crop(box)
     return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)
 
-# Use register_fake to add a FakeTensor kernel for the operator
+# Use register_fake to add a ``FakeTensor`` kernel for the operator
 @crop.register_fake
 def _(pic, box):
     channels = pic.shape[0]
     x0, y0, x1, y1 = box
     return pic.new_empty(channels, y1 - y0, x1 - x0)
 
 ######################################################################
-# After this, ``crop`` now works whout graph breaks:
+# After this, ``crop`` now works without graph breaks:
 
 @torch.compile(fullgraph=True)
 def f(img):
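
For context on the behavior the comments in these hunks describe: the ``register_fake`` kernel is what lets ``torch.compile(fullgraph=True)`` trace ``crop`` without a graph break. A self-contained sketch mirroring the tutorial code above, with an illustrative image size and crop box:

    import torch
    from typing import Sequence
    from torchvision.transforms.functional import to_pil_image, pil_to_tensor

    # Wrap the PIL-based crop into a custom operator (no inputs are mutated).
    @torch.library.custom_op("mylib::crop", mutates_args=())
    def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
        img = to_pil_image(pic.cpu())
        cropped_img = img.crop(box)
        return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

    # The FakeTensor kernel only reports output metadata (shape/dtype/device),
    # which is all the compiler needs while tracing.
    @crop.register_fake
    def _(pic, box):
        channels = pic.shape[0]
        x0, y0, x1, y1 = box
        return pic.new_empty(channels, y1 - y0, x1 - x0)

    @torch.compile(fullgraph=True)
    def f(img):
        return crop(img, (10, 10, 50, 50))

    print(f(torch.rand(3, 64, 64)).shape)  # torch.Size([3, 40, 40])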
@@ -214,7 +214,7 @@ def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
 
 ######################################################################
 # Because the operator doesn't return anything, there is no need to register
-# a FakeTensor kernel (meta kernel) to get it to work with ``torch.compile``.
+# a ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``.
 
 @torch.compile(fullgraph=True)
 def f(x):
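
And for this last hunk: because ``numpy_sin`` writes into ``output`` and returns ``None``, it needs no ``FakeTensor`` kernel to work under ``torch.compile``. A minimal sketch mirroring the tutorial's mutable operator (the test input is illustrative):

    import numpy as np
    import torch

    # ``output`` is listed in mutates_args because the op writes into it in place.
    @torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"})
    def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
        input_np = input.numpy()
        output_np = output.numpy()
        np.sin(input_np, out=output_np)

    @torch.compile(fullgraph=True)
    def f(x):
        out = torch.empty_like(x)
        numpy_sin(x, out)
        return out

    x = torch.randn(3)
    assert torch.allclose(f(x), torch.sin(x))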

en-wordlist.txt

Lines changed: 2 additions & 0 deletions
@@ -177,6 +177,7 @@ OpenSlide
 Opset
 Ornstein
 PIL
+PIL's
 PPO
 PatchPredictor
 PennFudan
@@ -443,6 +444,7 @@ multithreading
 namespace
 natively
 ndarrays
+nightlies
 num
 numericalize
 numpy
