@@ -23,18 +23,18 @@ https://github.com/pytorch/extension-cpp .
Build System
------------

- If you author custom C++/CUDA code, it needs to be compiled somehow.
+ If you are developing custom C++/CUDA code, it must be compiled.
Note that if you’re interfacing with a Python library that already has bindings
- to precompiled C++/CUDA code, then you may actually want to write a Python custom operator
- (TODO: tutorial)
+ to precompiled C++/CUDA code, you might consider writing a custom Python operator
+ instead (:ref:`python-custom-ops-tutorial`).

Use `torch.utils.cpp_extension <https://pytorch.org/docs/stable/cpp_extension.html>`_
to compile custom C++/CUDA code for use with PyTorch.
C++ extensions may be built either "ahead of time" with setuptools, or "just in time"
- via `load_inline <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load_inline>`;
+ via `load_inline <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load_inline>`_;
we’ll focus on the "ahead of time" flavor.

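+ For comparison, the "just in time" flavor compiles sources passed as strings at
+ runtime. A minimal sketch (the ``add_one`` function and module name are
+ hypothetical; ``load_inline`` prepends ``#include <torch/extension.h>`` and
+ generates Python bindings for the names listed in ``functions``):
+
+ .. code-block:: python
+
+    import torch
+    from torch.utils.cpp_extension import load_inline
+
+    cpp_source = """
+    at::Tensor add_one(const at::Tensor& x) {
+      return x + 1;  // hypothetical example: elementwise x + 1
+    }
+    """
+
+    # Compiles on first call; subsequent calls hit a build cache.
+    module = load_inline(
+        name="inline_ext", cpp_sources=cpp_source, functions=["add_one"])
+    print(module.add_one(torch.ones(3)))  # tensor([2., 2., 2.])
+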
- Using cpp_extension is as simple as writing the following setup.py:
+ Using ``cpp_extension`` is as simple as writing the following ``setup.py``:

.. code-block:: python

@@ -46,15 +46,18 @@ Using cpp_extension is as simple as writing the following setup.py:
          cpp_extension.CppExtension("extension_cpp", ["muladd.cpp"])],
      cmdclass={'build_ext': cpp_extension.BuildExtension})

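+ Pieced together, the whole file might look like this sketch (the package
+ ``name`` and the ``muladd.cpp`` filename are taken from the snippet above and
+ may differ in your project):
+
+ .. code-block:: python
+
+    from setuptools import setup
+    from torch.utils import cpp_extension
+
+    setup(
+        name="extension_cpp",  # assumed package name; match your project
+        ext_modules=[
+            cpp_extension.CppExtension("extension_cpp", ["muladd.cpp"])],
+        cmdclass={'build_ext': cpp_extension.BuildExtension})
+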
- If you need to compile CUDA code (e.g. .cu files), then instead use
+ If you need to compile CUDA code (for example, ``.cu`` files), then instead use
`torch.utils.cpp_extension.CUDAExtension <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension>`_.
- Please see how https://github.com/pytorch/extension-cpp is set up for more details.
+ See `extension-cpp <https://github.com/pytorch/extension-cpp>`_ for an example of
+ how this is set up.

Defining the custom op and adding backend implementations
---------------------------------------------------------
- First, let’s write a C++ function that computes mymuladd:
+ First, let’s write a C++ function that computes ``mymuladd``:

.. code-block:: cpp

+
at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat);
@@ -74,29 +77,31 @@ First, let’s write a C++ function that computes mymuladd:
}

In order to use this from PyTorch’s Python frontend, we need to register it
- as a PyTorch operator using the TORCH_LIBRARY API. This will automatically
+ as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically
bind the operator to Python.

Operator registration is a two-step process:

- - we need to define the operator (so that PyTorch knows about it)
- - we need to register various backend implementations (e.g. CPU/CUDA) to the operator
+ - **Defining the operator** - This step ensures that PyTorch is aware of the new operator.
+ - **Registering backend implementations** - In this step, implementations for various
+   backends, such as CPU and CUDA, are associated with the operator.

How to define an operator
^^^^^^^^^^^^^^^^^^^^^^^^^
- To define an operator:
+ To define an operator, follow these steps:

- - select a namespace for an operator. We recommend the namespace be the name of your top-level
-   project; we’ll use "extension_cpp" in our tutorial.
- - provide a schema string that specifies the input/output types of the operator and if an
-   input Tensors will be mutated. We support more types in addition to Tensor and float;
-   please see `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
-   for more details.
+ 1. Select a namespace for the operator. We recommend the namespace be the name of your top-level
+    project; we’ll use "extension_cpp" in our tutorial.
+ 2. Provide a schema string that specifies the input/output types of the operator and whether any
+    input Tensors will be mutated. We support more types in addition to Tensor and float;
+    please see `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
+    for more details.

- If you are authoring an operator that can mutate its input Tensors, please see here
- (:ref:`mutable-ops`) for how to specify that.
+ * If you are authoring an operator that can mutate its input Tensors, please see
+   :ref:`mutable-ops` for how to specify that.

.. code-block:: cpp

+
TORCH_LIBRARY(extension_cpp, m) {
  // Note that "float" in the schema corresponds to the C++ double type
  // and the Python float type.
@@ -107,16 +112,19 @@ This makes the operator available from Python via ``torch.ops.extension_cpp.mymu

How to register backend implementations for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- Use TORCH_LIBRARY_IMPL to register a backend implementation for the operator.
+ Use ``TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.

.. code-block:: cpp

+
TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
  m.impl("mymuladd", &mymuladd_cpu);
}

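+ With the definition and a CPU implementation registered, the operator is
+ callable from Python. A quick sanity check (a sketch; it assumes the built
+ extension has been imported so that the registrations have run):
+
+ .. code-block:: python
+
+    import torch
+    import extension_cpp  # importing the compiled module runs the registrations
+
+    a, b = torch.randn(3), torch.randn(3)
+    out = torch.ops.extension_cpp.mymuladd.default(a, b, 1.0)
+    assert torch.allclose(out, a * b + 1.0)
+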
- If we also have a CUDA implementation myaddmul_cuda, we can register it in a separate TORCH_LIBRARY_IMPL block:
+ If you also have a CUDA implementation of ``mymuladd``, you can register it
+ in a separate ``TORCH_LIBRARY_IMPL`` block:

.. code-block:: cpp

+
__global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] * b[idx] + c;
@@ -147,9 +155,9 @@ If we also have a CUDA implementation myaddmul_cuda, we can register it in a sep
How to add torch.compile support for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- To add torch.compile support for an operator, we must add a FakeTensor kernel (also
- known as a “meta kernel” or “abstract impl”). FakeTensors are Tensors that have
- metadata (i.e. shape, dtype, device) but no data: the FakeTensor kernel for an
+ To add ``torch.compile`` support for an operator, we must add a FakeTensor kernel (also
+ known as a "meta kernel" or "abstract impl"). FakeTensors are Tensors that have
+ metadata (such as shape, dtype, device) but no data: the FakeTensor kernel for an
operator specifies how to compute the metadata of output tensors given the metadata of input tensors.

We recommend that this be done from Python via the ``torch.library.register_fake`` API,
@@ -158,6 +166,7 @@ though it is possible to do this from C++ as well (see
for more details).

.. code-block:: python

+
@torch.library.register_fake("extension_cpp::mymuladd")
def _(a, b, c):
    torch._check(a.shape == b.shape)
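+ Once the FakeTensor kernel is registered, a quick way to exercise it is to
+ compile a function that calls the operator (a sketch; it assumes the extension
+ and the registration above have been loaded):
+
+ .. code-block:: python
+
+    @torch.compile(fullgraph=True)
+    def f(a, b):
+        return torch.ops.extension_cpp.mymuladd.default(a, b, 1.0)
+
+    # Tracing relies on the FakeTensor kernel to infer output metadata.
+    f(torch.randn(3), torch.randn(3))
+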
@@ -168,13 +177,14 @@ for more details).

How to add training (autograd) support for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- Use torch.library.register_autograd to add training support for an operator. Prefer
- this over directly using Python torch.autograd.Function or C++ torch::autograd::Function;
- one must use those in a very specific way to avoid silent incorrectness (see
+ Use ``torch.library.register_autograd`` to add training support for an operator. Prefer
+ this over directly using Python ``torch.autograd.Function`` or C++ ``torch::autograd::Function``;
+ you must use those in a very specific way to avoid silent incorrectness (see
`The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
for more details).

.. code-block:: python

+
def _backward(ctx, grad):
    a, b = ctx.saved_tensors
    grad_a, grad_b = None, None
@@ -193,20 +203,21 @@ for more details).
        saved_a = a
    ctx.save_for_backward(saved_a, saved_b)

- # This adds training support for the operator. You must provide us
+ # This code adds training support for the operator. You must provide us
# the backward formula for the operator and a `setup_context` function
# to save values to be used in the backward.
torch.library.register_autograd(
    "extension_cpp::mymuladd", _backward, setup_context=_setup_context)

Note that the backward must be a composition of PyTorch-understood operators.
If you wish to use another custom C++ or CUDA kernel in your backwards pass,
- it must be wrapped into a custom op.
+ it must be wrapped into a custom operator.

- So if we had our own custom mymul kernel, we would need to wrap it into a
+ If we had our own custom ``mymul`` kernel, we would need to wrap it into a
custom operator and then call that from the backward:

.. code-block:: cpp

+
// New! a mymul_cpu kernel
at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(a.sizes() == b.sizes());
@@ -261,19 +272,21 @@ custom operator and then call that from the backward:
    ctx.save_for_backward(saved_a, saved_b)


- # This adds training support for the operator. You must provide us
+ # This code adds training support for the operator. You must provide us
# the backward formula for the operator and a `setup_context` function
# to save values to be used in the backward.
torch.library.register_autograd(
    "extension_cpp::mymuladd", _backward, setup_context=_setup_context)

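+ After this registration, gradients flow through the operator like any other
+ PyTorch op. A minimal training-side sketch (assumes the extension is loaded;
+ for ``mymuladd(a, b, c) = a * b + c``, the gradient with respect to ``a`` is
+ ``grad * b``):
+
+ .. code-block:: python
+
+    a = torch.randn(3, requires_grad=True)
+    b = torch.randn(3)
+    loss = torch.ops.extension_cpp.mymuladd.default(a, b, 1.0).sum()
+    loss.backward()
+    assert torch.allclose(a.grad, b)  # since d(loss)/da == b
+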
How to test an operator
-----------------------
- Use torch.library.opcheck to test that the custom op was registered correctly.
- This does not test that the gradients are mathematically correct; please write
- separate tests for that (either manual ones or torch.autograd.gradcheck).
+ Use ``torch.library.opcheck`` to test that the custom op was registered correctly.
+ Note that this function does not test that the gradients are mathematically
+ correct; plan to write separate tests for that, either manual ones or by using
+ ``torch.autograd.gradcheck``.

.. code-block:: python

+
def sample_inputs(device, *, requires_grad=False):
    def make_tensor(*size):
        return torch.randn(size, device=device, requires_grad=requires_grad)
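+ With sample inputs in hand, each check is a single ``opcheck`` call (a sketch;
+ ``opcheck`` takes the operator, for example the resolved ``.default`` overload,
+ and a tuple of sample arguments):
+
+ .. code-block:: python
+
+    args = (torch.randn(3, requires_grad=True), torch.randn(3), 1.0)
+    torch.library.opcheck(torch.ops.extension_cpp.mymuladd.default, args)
+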
@@ -308,12 +321,13 @@ How to create mutable operators
-------------------------------
You may wish to author a custom operator that mutates its inputs. Use ``Tensor(a!)``
to specify each mutable Tensor in the schema; otherwise, there will be undefined
- behavior. If there are multiple mutated Tensors, use different names (i.e. ``Tensor(a!)``,
+ behavior. If there are multiple mutated Tensors, use different names (for example, ``Tensor(a!)``,
``Tensor(b!)``, ``Tensor(c!)``) for each mutable Tensor.

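+ For example, the schema of a hypothetical in-place operator that scales ``x``
+ by ``s`` could be written as ``myscale_(Tensor(a!) x, float s) -> ()``; the
+ ``(a!)`` annotation tells PyTorch that ``x`` is mutated.
+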
Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of ``a+b`` into ``out``.

.. code-block:: cpp

+
// An example of an operator that mutates one of its inputs.
void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
  TORCH_CHECK(a.sizes() == b.sizes());
@@ -338,6 +352,7 @@ Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of `
When defining the operator, we must specify that it mutates the ``out`` Tensor in the schema:

.. code-block:: cpp

+
TORCH_LIBRARY(extension_cpp, m) {
  m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
  m.def("mymul(Tensor a, Tensor b) -> Tensor");
@@ -352,14 +367,14 @@ When defining the operator, we must specify that it mutates the out Tensor in th
  m.impl("myadd_out", &myadd_out_cpu);
}

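+ From Python, calling the operator mutates ``out`` in place (a sketch; assumes
+ the extension is loaded):
+
+ .. code-block:: python
+
+    a, b = torch.randn(3), torch.randn(3)
+    out = torch.empty(3)
+    torch.ops.extension_cpp.myadd_out.default(a, b, out)
+    assert torch.allclose(out, a + b)
+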
- Please do not return any mutated Tensors as outputs of the operator; this will
- run you into problems later down the line.
+ .. note::
+
+    Do not return any mutated Tensors as outputs of the operator, as this will
+    cause incompatibility with PyTorch subsystems like ``torch.compile``.

Conclusion
----------
In this tutorial, we went over the recommended approach to integrating custom C++
- and CUDA operators with PyTorch. The TORCH_LIBRARY/torch.library APIs are fairly
- low-level; more detail about how to use them can be found over at
- `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
-
-
+ and CUDA operators with PyTorch. The ``TORCH_LIBRARY``/``torch.library`` APIs are fairly
+ low-level. For more information about how to use these APIs, see
+ `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_.