Commit bc9b6e9

Fix cpp custom ops tutorial per review
1 parent 3634837 commit bc9b6e9

2 files changed: +60 additions, -43 deletions

advanced_source/cpp_custom_ops.rst

Lines changed: 58 additions & 43 deletions
@@ -23,18 +23,18 @@ https://github.com/pytorch/extension-cpp .
 Build System
 ------------
 
-If you author custom C++/CUDA code, it needs to be compiled somehow.
+If you are developing custom C++/CUDA code, it must be compiled.
 Note that if you’re interfacing with a Python library that already has bindings
-to precompiled C++/CUDA code, then you may actually want to write a Python custom operator
-(TODO: tutorial)
+to precompiled C++/CUDA code, you might consider writing a custom Python operator
+instead (:ref:`python-custom-ops-tutorial`).
 
 Use `torch.utils.cpp_extension <https://pytorch.org/docs/stable/cpp_extension.html>`_
 to compile custom C++/CUDA code for use with PyTorch.
 C++ extensions may be built either "ahead of time" with setuptools, or "just in time"
-via `load_inline <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load_inline>`;
+via `load_inline <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load_inline>`_;
 we’ll focus on the "ahead of time" flavor.
 
-Using cpp_extension is as simple as writing the following setup.py:
+Using ``cpp_extension`` is as simple as writing the following ``setup.py``:
 
 .. code-block:: python
 
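
For context, the complete ``setup.py`` that the fragment in the next hunk belongs
to would look roughly like the following minimal sketch (assuming a flat layout
with ``muladd.cpp`` next to ``setup.py``; see extension-cpp for the real build file):

.. code-block:: python

   # Minimal ahead-of-time build sketch; assumes muladd.cpp sits in this directory.
   from setuptools import setup
   from torch.utils import cpp_extension

   setup(
       name="extension_cpp",
       ext_modules=[
           cpp_extension.CppExtension("extension_cpp", ["muladd.cpp"])],
       cmdclass={"build_ext": cpp_extension.BuildExtension})

Building then follows the usual setuptools flow, for example ``pip install -e .``
from the project root.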
@@ -46,15 +46,18 @@ Using cpp_extension is as simple as writing the following setup.py:
           cpp_extension.CppExtension("extension_cpp", ["muladd.cpp"])],
       cmdclass={'build_ext': cpp_extension.BuildExtension})
 
-If you need to compile CUDA code (e.g. .cu files), then instead use
+If you need to compile CUDA code (for example, ``.cu`` files), then instead use
 `torch.utils.cpp_extension.CUDAExtension <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension>`_.
-Please see how https://github.com/pytorch/extension-cpp is set up for more details.
+Please see
+`extension-cpp <https://github.com/pytorch/extension-cpp>`_ for an example of
+how this is set up.
 
 Defining the custom op and adding backend implementations
 ---------------------------------------------------------
-First, let’s write a C++ function that computes mymuladd:
+First, let’s write a C++ function that computes ``mymuladd``:
 
 .. code-block:: cpp
+
   at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
     TORCH_CHECK(a.sizes() == b.sizes());
     TORCH_CHECK(a.dtype() == at::kFloat);
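
An aside for readers following along: the kernel above implements elementwise
``a * b + c``. A pure-Python reference (illustrative only, not part of the
tutorial) is useful for sanity-checking outputs later:

.. code-block:: python

   import torch

   def mymuladd_reference(a: torch.Tensor, b: torch.Tensor, c: float) -> torch.Tensor:
       # Same semantics as the C++ mymuladd_cpu kernel: elementwise a * b + c.
       return a * b + c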
@@ -74,29 +77,31 @@ First, let’s write a C++ function that computes mymuladd:
   }
 
 In order to use this from PyTorch’s Python frontend, we need to register it
-as a PyTorch operator using the TORCH_LIBRARY API. This will automatically
+as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically
 bind the operator to Python.
 
 Operator registration is a two-step process:
 
-- we need to define the operator (so that PyTorch knows about it)
-- we need to register various backend implementations (e.g. CPU/CUDA) to the operator
+- **Defining the operator** - This step ensures that PyTorch is aware of the new operator.
+- **Registering backend implementations** - In this step, implementations for various
+  backends, such as CPU and CUDA, are associated with the operator.
 
 How to define an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^
-To define an operator:
+To define an operator, follow these steps:
 
-- select a namespace for an operator. We recommend the namespace be the name of your top-level
-  project; we’ll use "extension_cpp" in our tutorial.
-- provide a schema string that specifies the input/output types of the operator and if an
-  input Tensors will be mutated. We support more types in addition to Tensor and float;
-  please see `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
-  for more details.
+1. Select a namespace for the operator. We recommend the namespace be the name of your top-level
+   project; we’ll use "extension_cpp" in our tutorial.
+2. Provide a schema string that specifies the input/output types of the operator and whether
+   any input Tensors will be mutated. We support more types in addition to Tensor and float;
+   please see `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
+   for more details.
 
-If you are authoring an operator that can mutate its input Tensors, please see here
-(:ref:`mutable-ops`) for how to specify that.
+* If you are authoring an operator that can mutate its input Tensors, please see
+  :ref:`mutable-ops` for how to specify that.
 
 .. code-block:: cpp
+
   TORCH_LIBRARY(extension_cpp, m) {
     // Note that "float" in the schema corresponds to the C++ double type
     // and the Python float type.
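
Once the definition is registered (and per the unchanged text that follows this
hunk, the op becomes available as ``torch.ops.extension_cpp.mymuladd``), calling
it from Python looks like the sketch below; the ``extension_cpp`` import name is
an assumption about how the built module is exposed:

.. code-block:: python

   import torch
   import extension_cpp  # assumed import name; importing runs TORCH_LIBRARY registration

   a = torch.randn(3)
   b = torch.randn(3)
   # Schema: "mymuladd(Tensor a, Tensor b, float c) -> Tensor"
   result = torch.ops.extension_cpp.mymuladd(a, b, 1.0)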
@@ -107,16 +112,19 @@ This makes the operator available from Python via ``torch.ops.extension_cpp.mymu
 
 How to register backend implementations for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Use TORCH_LIBRARY_IMPL to register a backend implementation for the operator.
+Use ``TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.
 
 .. code-block:: cpp
+
   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
     m.impl("mymuladd", &mymuladd_cpu);
   }
 
-If we also have a CUDA implementation myaddmul_cuda, we can register it in a separate TORCH_LIBRARY_IMPL block:
+If you also have a CUDA implementation of ``mymuladd``, you can register it
+in a separate ``TORCH_LIBRARY_IMPL`` block:
 
 .. code-block:: cpp
+
   __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < numel) result[idx] = a[idx] * b[idx] + c;
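
With both registrations in place, the dispatcher selects an implementation based
on the device of the input Tensors; no Python-side branching is needed. A small
sketch (again assuming the built extension imports as ``extension_cpp``):

.. code-block:: python

   import torch
   import extension_cpp  # assumed import name for the built extension

   args = (torch.randn(3), torch.randn(3), 1.0)
   out_cpu = torch.ops.extension_cpp.mymuladd(*args)  # dispatches to mymuladd_cpu

   if torch.cuda.is_available():
       a, b, c = args
       # Same op, CUDA inputs: dispatches to the CUDA kernel registered above.
       out_cuda = torch.ops.extension_cpp.mymuladd(a.cuda(), b.cuda(), c)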
@@ -147,9 +155,9 @@ If we also have a CUDA implementation myaddmul_cuda, we can register it in a sep
 How to add torch.compile support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-To add torch.compile support for an operator, we must add a FakeTensor kernel (also
-known as a meta kernel or abstract impl). FakeTensors are Tensors that have
-metadata (i.e. shape, dtype, device) but no data: the FakeTensor kernel for an
+To add ``torch.compile`` support for an operator, we must add a FakeTensor kernel (also
+known as a "meta kernel" or "abstract impl"). FakeTensors are Tensors that have
+metadata (such as shape, dtype, device) but no data: the FakeTensor kernel for an
 operator specifies how to compute the metadata of output tensors given the metadata of input tensors.
 
 We recommend that this be done from Python via the `torch.library.register_fake` API,
@@ -158,6 +166,7 @@ though it is possible to do this from C++ as well (see
 for more details).
 
 .. code-block:: python
+
   @torch.library.register_fake("extension_cpp::mymuladd")
   def _(a, b, c):
       torch._check(a.shape == b.shape)
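
The hunk shows only the first check of the fake kernel. A plausible completion
(mirroring the dtype and device ``TORCH_CHECK``s in the CPU kernel; the exact
remaining lines are not in this diff) plus a ``torch.compile`` smoke test:

.. code-block:: python

   import torch

   @torch.library.register_fake("extension_cpp::mymuladd")
   def _(a, b, c):
       # Checks mirror the TORCH_CHECKs in mymuladd_cpu.
       torch._check(a.shape == b.shape)
       torch._check(a.dtype == torch.float)
       torch._check(b.dtype == torch.float)
       torch._check(a.device == b.device)
       # No data is computed; only output metadata is described.
       return torch.empty_like(a)

   @torch.compile(fullgraph=True)
   def f(a, b):
       return torch.ops.extension_cpp.mymuladd(a, b, 1.0)

   f(torch.randn(3), torch.randn(3))  # traces through the fake kernel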
@@ -168,13 +177,14 @@ for more details).
 
 How to add training (autograd) support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Use torch.library.register_autograd to add training support for an operator. Prefer
-this over directly using Python torch.autograd.Function or C++ torch::autograd::Function;
-one must use those in a very specific way to avoid silent incorrectness (see
+Use ``torch.library.register_autograd`` to add training support for an operator. Prefer
+this over directly using Python ``torch.autograd.Function`` or C++ ``torch::autograd::Function``;
+you must use those in a very specific way to avoid silent incorrectness (see
 `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
 for more details).
 
 .. code-block:: python
+
   def _backward(ctx, grad):
       a, b = ctx.saved_tensors
       grad_a, grad_b = None, None
@@ -193,20 +203,21 @@ for more details).
           saved_a = a
       ctx.save_for_backward(saved_a, saved_b)
 
-  # This adds training support for the operator. You must provide us
+  # This code adds training support for the operator. You must provide us
   # the backward formula for the operator and a `setup_context` function
   # to save values to be used in the backward.
   torch.library.register_autograd(
       "extension_cpp::mymuladd", _backward, setup_context=_setup_context)
 
 Note that the backward must be a composition of PyTorch-understood operators.
 If you wish to use another custom C++ or CUDA kernel in your backwards pass,
-it must be wrapped into a custom op.
+it must be wrapped into a custom operator.
 
-So if we had our own custom mymul kernel, we would need to wrap it into a
+If we had our own custom ``mymul`` kernel, we would need to wrap it into a
 custom operator and then call that from the backward:
 
 .. code-block:: cpp
+
   // New! a mymul_cpu kernel
   at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
     TORCH_CHECK(a.sizes() == b.sizes());
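
For readers assembling the pieces: put together from the fragments in the last
two hunks, the first (plain-PyTorch backward) registration looks roughly like the
sketch below. The saved-tensor bookkeeping is reconstructed, so treat the details
as assumptions rather than the tutorial's exact code:

.. code-block:: python

   import torch

   def _backward(ctx, grad):
       a, b = ctx.saved_tensors
       grad_a, grad_b = None, None
       # Backward of a * b + c, composed only of ops PyTorch understands.
       if ctx.needs_input_grad[0]:
           grad_a = grad * b
       if ctx.needs_input_grad[1]:
           grad_b = grad * a
       return grad_a, grad_b, None  # no gradient for the float c

   def _setup_context(ctx, inputs, output):
       a, b, c = inputs
       saved_a, saved_b = None, None
       if ctx.needs_input_grad[0]:
           saved_b = b
       if ctx.needs_input_grad[1]:
           saved_a = a
       ctx.save_for_backward(saved_a, saved_b)

   torch.library.register_autograd(
       "extension_cpp::mymuladd", _backward, setup_context=_setup_context)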
@@ -261,19 +272,21 @@ custom operator and then call that from the backward:
       ctx.save_for_backward(saved_a, saved_b)
 
 
-  # This adds training support for the operator. You must provide us
+  # This code adds training support for the operator. You must provide us
   # the backward formula for the operator and a `setup_context` function
   # to save values to be used in the backward.
   torch.library.register_autograd(
       "extension_cpp::mymuladd", _backward, setup_context=_setup_context)
 
 How to test an operator
 -----------------------
-Use torch.library.opcheck to test that the custom op was registered correctly.
-This does not test that the gradients are mathematically correct; please write
-separate tests for that (either manual ones or torch.autograd.gradcheck).
+Use ``torch.library.opcheck`` to test that the custom op was registered correctly.
+Note that this function does not test that the gradients are mathematically correct;
+plan to write separate tests for that, either manual ones or by using
+``torch.autograd.gradcheck``.
 
 .. code-block:: python
+
   def sample_inputs(device, *, requires_grad=False):
       def make_tensor(*size):
           return torch.randn(size, device=device, requires_grad=requires_grad)
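
A sketch of how ``sample_inputs`` might drive ``opcheck``; the loop below is an
assumed test harness rather than part of this diff, and it presumes
``sample_inputs`` returns a list of argument tuples for the operator:

.. code-block:: python

   import torch
   from torch.library import opcheck

   def test_mymuladd(device):
       # opcheck exercises the schema, fake kernel, and autograd registration.
       for args in sample_inputs(device, requires_grad=True):
           opcheck(torch.ops.extension_cpp.mymuladd.default, args)

   test_mymuladd("cpu")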
@@ -308,12 +321,13 @@ How to create mutable operators
 -------------------------------
 You may wish to author a custom operator that mutates its inputs. Use ``Tensor(a!)``
 to specify each mutable Tensor in the schema; otherwise, there will be undefined
-behavior. If there are multiple mutated Tensors, use different names (i.e. ``Tensor(a!)``,
+behavior. If there are multiple mutated Tensors, use different names (for example, ``Tensor(a!)``,
 ``Tensor(b!)``, ``Tensor(c!)``) for each mutable Tensor.
 
 Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of ``a+b`` into ``out``.
 
 .. code-block:: cpp
+
   // An example of an operator that mutates one of its inputs.
   void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
     TORCH_CHECK(a.sizes() == b.sizes());
@@ -338,6 +352,7 @@ Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of `
 When defining the operator, we must specify that it mutates the out Tensor in the schema:
 
 .. code-block:: cpp
+
   TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
     m.def("mymul(Tensor a, Tensor b) -> Tensor");
@@ -352,14 +367,14 @@ When defining the operator, we must specify that it mutates the out Tensor in th
     m.impl("myadd_out", &myadd_out_cpu);
   }
 
-Please do not return any mutated Tensors as outputs of the operator; this will
-run you into problems later down the line.
+.. note::
+
+   Do not return any mutated Tensors as outputs of the operator, as this will
+   cause incompatibility with PyTorch subsystems like ``torch.compile``.
 
 Conclusion
 ----------
 In this tutorial, we went over the recommended approach to integrating custom C++
-and CUDA operators with PyTorch. The TORCH_LIBRARY/torch.library APIs are fairly
-low-level; more detail about how to use them can be found over at
-`The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
-
-
+and CUDA operators with PyTorch. The ``TORCH_LIBRARY``/``torch.library`` APIs are fairly
+low-level. For more information about how to use the API, see
+`The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_.
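
To close the loop on the mutable-op section, a short Python usage sketch (again
assuming the built extension imports as ``extension_cpp``; the ``Tensor(a!)``
annotation on ``out`` tells PyTorch the op writes ``a + b`` into it and returns
nothing):

.. code-block:: python

   import torch
   import extension_cpp  # assumed import name for the built extension

   a, b = torch.randn(3), torch.randn(3)
   out = torch.empty(3)
   torch.ops.extension_cpp.myadd_out(a, b, out)  # mutates `out` in place
   torch.testing.assert_close(out, a + b)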

advanced_source/python_custom_ops.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 
 """
+.. _python-custom-ops-tutorial:
+
 Python Custom Operators
 =======================
 