.. _cpp-custom-ops-tutorial:

Custom C++ and CUDA Operators
=============================

**Author:** `Richard Zou <https://github.com/zou3519>`_

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

        * How to integrate custom operators written in C++/CUDA with PyTorch
        * How to test custom operators using ``torch.library.opcheck``

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

        * PyTorch 2.4 or later
        * Basic understanding of C++ and CUDA programming

PyTorch offers a large library of operators that work on Tensors (e.g. ``torch.add``, ``torch.sum``, etc.).
However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the
@@ -17,11 +30,11 @@ the operation are as follows:

    def mymuladd(a: Tensor, b: Tensor, c: float):
        return a * b + c

You can find the end-to-end working example for this tutorial
`here <https://github.com/pytorch/extension-cpp>`_.

Setting up the Build System
---------------------------

If you are developing custom C++/CUDA code, it must be compiled.
Note that if you’re interfacing with a Python library that already has bindings
@@ -54,7 +67,7 @@ how this is set up.
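
As a rough sketch, a ``setuptools``-based build that uses ``torch.utils.cpp_extension``
might look like the following; the module name and source path are illustrative rather
than the exact layout of the extension-cpp repository:

.. code-block:: python

    # setup.py -- a minimal sketch; package and file names are hypothetical.
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CppExtension

    setup(
        name="extension_cpp",
        ext_modules=[
            CppExtension(
                "extension_cpp._C",                 # compiled module importable from Python
                ["extension_cpp/csrc/muladd.cpp"],  # hypothetical C++ source file
            ),
        ],
        # BuildExtension supplies the PyTorch-specific compiler and linker flags.
        cmdclass={"build_ext": BuildExtension},
    )

For CUDA sources, ``CUDAExtension`` plays the same role; see the extension-cpp
repository linked above for the complete, tested setup.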

Defining the custom op and adding backend implementations
---------------------------------------------------------
First, let's write a C++ function that computes ``mymuladd``:

.. code-block:: cpp

@@ -86,8 +99,8 @@ Operator registration is a two step-process:

- **Registering backend implementations** - In this step, implementations for various
  backends, such as CPU and CUDA, are associated with the operator.

Defining an operator
^^^^^^^^^^^^^^^^^^^^
To define an operator, follow these steps:

1. Select a namespace for an operator. We recommend the namespace be the name of your top-level
@@ -110,8 +123,8 @@ To define an operator, follow these steps:

This makes the operator available from Python via ``torch.ops.extension_cpp.mymuladd``.

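As a quick sketch of the Python side (assuming the extension has been built and
imported, and a backend implementation has been registered as described in the
next section):

.. code-block:: python

    import torch
    import extension_cpp  # hypothetical package import that loads the C++ registrations

    a = torch.randn(3)
    b = torch.randn(3)
    # Calls into whichever backend kernel is registered for the inputs' device;
    # numerically this computes a * b + 1.0.
    out = torch.ops.extension_cpp.mymuladd(a, b, 1.0)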

Registering backend implementations for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Use ``TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.

.. code-block:: cpp
@@ -129,7 +142,7 @@ in a separate ``TORCH_LIBRARY_IMPL`` block:
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < numel) result[idx] = a[idx] * b[idx] + c;
  }

  at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
    TORCH_CHECK(a.sizes() == b.sizes());
    TORCH_CHECK(a.dtype() == at::kFloat);
@@ -142,17 +155,17 @@ in a separate ``TORCH_LIBRARY_IMPL`` block:
    const float* a_ptr = a_contig.data_ptr<float>();
    const float* b_ptr = b_contig.data_ptr<float>();
    float* result_ptr = result.data_ptr<float>();

    int numel = a_contig.numel();
    muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
    return result;
  }

  TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
    m.impl("mymuladd", &mymuladd_cuda);
  }

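With the CUDA registration in place, the same Python call dispatches to
``mymuladd_cuda`` whenever the inputs are CUDA tensors. A small usage sketch,
assuming a CUDA-enabled build and device:

.. code-block:: python

    import torch

    a = torch.randn(1024, device="cuda")
    b = torch.randn(1024, device="cuda")
    # The dispatcher selects the CUDA kernel because the inputs live on the GPU.
    out = torch.ops.extension_cpp.mymuladd(a, b, 2.0)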

Adding ``torch.compile`` support for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To add ``torch.compile`` support for an operator, we must add a FakeTensor kernel (also
@@ -182,8 +195,8 @@ for more details).
        torch._check(a.device == b.device)
        return torch.empty_like(a)

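Putting the pieces together, the FakeTensor kernel for ``extension_cpp::mymuladd``
is attached with ``torch.library.register_fake``; the following is a minimal sketch
of that pattern (the checks mirror the C++ implementations above):

.. code-block:: python

    import torch

    @torch.library.register_fake("extension_cpp::mymuladd")
    def _(a, b, c):
        # Only shapes, dtypes, and devices are propagated; no real data is touched.
        torch._check(a.shape == b.shape)
        torch._check(a.dtype == torch.float)
        torch._check(b.dtype == torch.float)
        torch._check(a.device == b.device)
        return torch.empty_like(a)

With a FakeTensor kernel registered, ``torch.compile`` can trace through calls to the
operator without running the real kernels.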

Setting up hybrid Python/C++ registration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In this tutorial, we defined a custom operator in C++, added CPU/CUDA
implementations in C++, and added ``FakeTensor`` kernels and backward formulas
in Python. The order in which these registrations are loaded (or imported)
@@ -199,9 +212,9 @@ of two ways:
2. If your C++ custom operator is located in a shared library object, you can
   also use ``torch.ops.load_library("/path/to/library.so")`` to load it.

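For example, a module that combines both halves might make sure the C++ registrations
are loaded before any Python-side registrations are applied; a hedged sketch, with
placeholder paths and names:

.. code-block:: python

    import torch

    # 1. Load the C++ registrations first, either by importing the compiled
    #    extension module or by loading the shared library directly.
    torch.ops.load_library("/path/to/library.so")  # placeholder path

    # 2. Only then attach the Python pieces, so they refer to an operator
    #    that already exists, e.g.:
    #    torch.library.register_fake("extension_cpp::mymuladd", ...)
    #    torch.library.register_autograd("extension_cpp::mymuladd", ...)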

Adding training (autograd) support for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Use ``torch.library.register_autograd`` to add training support for an operator. Prefer
this over directly using Python ``torch.autograd.Function`` or C++ ``torch::autograd::Function``;
you must use those in a very specific way to avoid silent incorrectness (see
@@ -218,7 +231,7 @@ for more details).
        if ctx.needs_input_grad[1]:
            grad_b = grad * a
        return grad_a, grad_b, None

    def _setup_context(ctx, inputs, output):
        a, b, c = inputs
        saved_a, saved_b = None, None
@@ -227,7 +240,7 @@ for more details).
        if ctx.needs_input_grad[1]:
            saved_a = a
        ctx.save_for_backward(saved_a, saved_b)

    # This code adds training support for the operator. You must provide us
    # the backward formula for the operator and a `setup_context` function
    # to save values to be used in the backward.
@@ -248,8 +261,8 @@ custom operator and then call that from the backward:
    TORCH_CHECK(a.sizes() == b.sizes());
    TORCH_CHECK(a.dtype() == at::kFloat);
    TORCH_CHECK(b.dtype() == at::kFloat);
    TORCH_CHECK(a.device().type() == at::DeviceType::CPU);
    TORCH_CHECK(b.device().type() == at::DeviceType::CPU);
    at::Tensor a_contig = a.contiguous();
    at::Tensor b_contig = b.contiguous();
    at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
@@ -261,14 +274,14 @@ custom operator and then call that from the backward:
    }
    return result;
  }

  TORCH_LIBRARY(extension_cpp, m) {
    m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
    // New! defining the mymul operator
    m.def("mymul(Tensor a, Tensor b) -> Tensor");
  }


  TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
    m.impl("mymuladd", &mymuladd_cpu);
    // New! registering the cpu kernel for the mymul operator
@@ -285,8 +298,8 @@ custom operator and then call that from the backward:
        if ctx.needs_input_grad[1]:
            grad_b = torch.ops.extension_cpp.mymul.default(grad, a)
        return grad_a, grad_b, None


    def _setup_context(ctx, inputs, output):
        a, b, c = inputs
        saved_a, saved_b = None, None
@@ -295,16 +308,16 @@ custom operator and then call that from the backward:
        if ctx.needs_input_grad[1]:
            saved_a = a
        ctx.save_for_backward(saved_a, saved_b)


    # This code adds training support for the operator. You must provide us
    # the backward formula for the operator and a `setup_context` function
    # to save values to be used in the backward.
    torch.library.register_autograd(
        "extension_cpp::mymuladd", _backward, setup_context=_setup_context)

Testing an operator
-------------------
Use ``torch.library.opcheck`` to test that the custom op was registered correctly.
Note that this function does not test that the gradients are mathematically correct
-- plan to write separate tests for that, either manual ones or by using
@@ -315,36 +328,36 @@ Note that this function does not test that the gradients are mathematically corr
    def sample_inputs(device, *, requires_grad=False):
        def make_tensor(*size):
            return torch.randn(size, device=device, requires_grad=requires_grad)

        def make_nondiff_tensor(*size):
            return torch.randn(size, device=device, requires_grad=False)

        return [
            [make_tensor(3), make_tensor(3), 1],
            [make_tensor(20), make_tensor(20), 3.14],
            [make_tensor(20), make_nondiff_tensor(20), -123],
            [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
        ]

    def reference_muladd(a, b, c):
        return a * b + c

    samples = sample_inputs(device, requires_grad=True)
    samples.extend(sample_inputs(device, requires_grad=False))
    for args in samples:
        # Correctness test
        result = torch.ops.extension_cpp.mymuladd(*args)
        expected = reference_muladd(*args)
        torch.testing.assert_close(result, expected)

        # Use opcheck to check for incorrect usage of operator registration APIs
        torch.library.opcheck(torch.ops.extension_cpp.mymuladd.default, args)

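Because ``opcheck`` does not verify the gradient math itself, a separate test can
compare the custom operator's gradients against a pure-PyTorch reference. A minimal
sketch, assuming the registrations above have been loaded:

.. code-block:: python

    import torch

    a = torch.randn(8, requires_grad=True)
    b = torch.randn(8, requires_grad=True)

    # Gradients of the custom operator.
    out = torch.ops.extension_cpp.mymuladd(a, b, 0.5)
    grad_a, grad_b = torch.autograd.grad(out.sum(), (a, b))

    # Gradients of the eager reference, differentiated by autograd.
    ref = (a * b + 0.5).sum()
    ref_grad_a, ref_grad_b = torch.autograd.grad(ref, (a, b))

    torch.testing.assert_close(grad_a, ref_grad_a)
    torch.testing.assert_close(grad_b, ref_grad_b)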

.. _mutable-ops:

Creating mutable operators
--------------------------
You may wish to author a custom operator that mutates its inputs. Use ``Tensor(a!)``
to specify each mutable Tensor in the schema; otherwise, there will be undefined
behavior. If there are multiple mutated Tensors, use different names (for example, ``Tensor(a!)``,
``Tensor(b!)``, ``Tensor(c!)``) for each mutable Tensor.
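
As a brief sketch, a hypothetical in-place variant that writes its result into an
``out`` tensor could be declared as follows (the operator name and signature are
illustrative):

.. code-block:: cpp

  TORCH_LIBRARY(extension_cpp, m) {
    // ... existing defs from earlier in the tutorial ...
    // `out` is mutated in place, so it is annotated Tensor(a!);
    // the operator returns nothing.
    m.def("mymuladd_out(Tensor a, Tensor b, float c, Tensor(a!) out) -> ()");
  }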