
Commit dcccabd

Add new Python and C++/CUDA Custom Op tutorials

I want to land this before PyTorch 2.4 (so we can link to these in PyTorch's nightly documentation) and then have a follow-up PR for 2.4 that actually runs the scripts (so that they can generate outputs). pytorch/pytorch#127443 to remind myself of the above.

NB: These two tutorials replace all of the existing custom ops (and cpp extensions) tutorials:

- advanced/cpp_extension
- advanced/torch_script_custom_ops
- advanced/torch_script_custom_classes
- advanced/dispatcher

In a follow-up PR we will add warnings to all of those tutorials stating that they are deprecated, but we will preserve the text just in case people still need them (e.g. if they are not using PyTorch 2.4).

Test Plan:
- I tested these locally.

1 parent 2dd1997 commit dcccabd

File tree

4 files changed: +618 -0 lines changed

.jenkins/validate_tutorials_built.py

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@
     "intermediate_source/fx_conv_bn_fuser",
     "intermediate_source/_torch_export_nightly_tutorial",  # does not work on release
     "advanced_source/super_resolution_with_onnxruntime",
+    "advanced_source/python_custom_ops",  # https://github.com/pytorch/pytorch/issues/127443
     "advanced_source/ddp_pipeline",  # requires 4 gpus
     "advanced_source/usb_semisup_learn",  # fails with CUDA OOM error, should try on a different worker
     "prototype_source/fx_graph_mode_ptq_dynamic",

advanced_source/cpp_custom_ops.rst

Lines changed: 365 additions & 0 deletions

Custom C++ and CUDA Operators
=============================

.. note::
   This tutorial is for PyTorch 2.4+ and the PyTorch nightlies.

PyTorch offers a large library of operators that work on Tensors (e.g. torch.add, torch.sum, etc).
However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the
blessed path to authoring a custom operator written in C++/CUDA.

For our tutorial, we’ll demonstrate how to author a fused multiply-add C++
and CUDA operator that composes with PyTorch subsystems. The semantics of
the operation are as follows:

.. code-block:: python

   def mymuladd(a: Tensor, b: Tensor, c: float):
       return a * b + c

You can find the end-to-end working example for this tutorial over at
https://github.com/pytorch/extension-cpp .

Build System
------------

If you author custom C++/CUDA code, it needs to be compiled somehow.
Note that if you’re interfacing with a Python library that already has bindings
to precompiled C++/CUDA code, then you may actually want to write a Python custom operator
instead (TODO: tutorial).

Use `torch.utils.cpp_extension <https://pytorch.org/docs/stable/cpp_extension.html>`_
to compile custom C++/CUDA code for use with PyTorch.
C++ extensions may be built either "ahead of time" with setuptools, or "just in time"
via `load_inline <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load_inline>`_;
we’ll focus on the "ahead of time" flavor.

Using cpp_extension is as simple as writing the following setup.py:

.. code-block:: python

   from setuptools import setup, Extension
   from torch.utils import cpp_extension

   setup(name="extension_cpp",
         ext_modules=[
             cpp_extension.CppExtension("extension_cpp", ["muladd.cpp"])],
         cmdclass={'build_ext': cpp_extension.BuildExtension})

If you need to compile CUDA code (e.g. .cu files), then instead use
`torch.utils.cpp_extension.CUDAExtension <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension>`_.
Please see how https://github.com/pytorch/extension-cpp is set up for more details.
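
For reference, a CUDA-enabled setup.py might look roughly like the sketch below. This is only a
sketch: the source file names (``muladd.cpp``, ``muladd.cu``) are assumptions here; see the
extension-cpp repository for the actual layout.

.. code-block:: python

   from setuptools import setup
   from torch.utils import cpp_extension

   setup(name="extension_cpp",
         ext_modules=[
             cpp_extension.CUDAExtension(
                 "extension_cpp",
                 # Hypothetical file names; CUDAExtension compiles the .cu file with nvcc.
                 ["muladd.cpp", "muladd.cu"])],
         cmdclass={'build_ext': cpp_extension.BuildExtension})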

Defining the custom op and adding backend implementations
---------------------------------------------------------

First, let’s write a C++ function that computes mymuladd:

.. code-block:: cpp

   at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
     TORCH_CHECK(a.sizes() == b.sizes());
     TORCH_CHECK(a.dtype() == at::kFloat);
     TORCH_CHECK(b.dtype() == at::kFloat);
     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* result_ptr = result.data_ptr<float>();
     for (int64_t i = 0; i < result.numel(); i++) {
       result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
     }
     return result;
   }

In order to use this from PyTorch’s Python frontend, we need to register it
as a PyTorch operator using the TORCH_LIBRARY API. This will automatically
bind the operator to Python.

Operator registration is a two-step process:

- we need to define the operator (so that PyTorch knows about it)
- we need to register various backend implementations (e.g. CPU/CUDA) to the operator

How to define an operator
^^^^^^^^^^^^^^^^^^^^^^^^^

To define an operator:

- select a namespace for the operator. We recommend the namespace be the name of your top-level
  project; we’ll use "extension_cpp" in our tutorial.
- provide a schema string that specifies the input/output types of the operator and whether any
  input Tensors will be mutated. We support more types in addition to Tensor and float;
  please see `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
  for more details.

If you are authoring an operator that can mutate its input Tensors, please see
:ref:`mutable-ops` for how to specify that.

.. code-block:: cpp

   TORCH_LIBRARY(extension_cpp, m) {
     // Note that "float" in the schema corresponds to the C++ double type
     // and the Python float type.
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
   }

This makes the operator available from Python via ``torch.ops.extension_cpp.mymuladd``.
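
For example, once the extension above has been built and installed, calling the operator from
Python might look like the sketch below. The ``import extension_cpp`` line assumes the package
name from the setup.py above; importing the compiled module is what runs the TORCH_LIBRARY
registrations.

.. code-block:: python

   import torch
   import extension_cpp  # assumed package name; importing it registers the operators

   a = torch.randn(3)
   b = torch.randn(3)
   out = torch.ops.extension_cpp.mymuladd(a, b, 1.0)
   print(torch.allclose(out, a * b + 1.0))  # expected: True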

How to register backend implementations for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Use TORCH_LIBRARY_IMPL to register a backend implementation for the operator.

.. code-block:: cpp

   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
     m.impl("mymuladd", &mymuladd_cpu);
   }

If we also have a CUDA implementation mymuladd_cuda, we can register it in a separate TORCH_LIBRARY_IMPL block:

.. code-block:: cpp

   __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < numel) result[idx] = a[idx] * b[idx] + c;
   }

   at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
     TORCH_CHECK(a.sizes() == b.sizes());
     TORCH_CHECK(a.dtype() == at::kFloat);
     TORCH_CHECK(b.dtype() == at::kFloat);
     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* result_ptr = result.data_ptr<float>();

     int numel = a_contig.numel();
     muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
     return result;
   }

   TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
     m.impl("mymuladd", &mymuladd_cuda);
   }
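
With both kernels registered, dispatch happens based on the device of the input Tensors: the same
Python call routes to the CPU or CUDA implementation automatically. A rough sketch (assumes a
CUDA-capable build of the extension):

.. code-block:: python

   a, b = torch.randn(3), torch.randn(3)
   torch.ops.extension_cpp.mymuladd(a, b, 1.0)                # runs mymuladd_cpu
   torch.ops.extension_cpp.mymuladd(a.cuda(), b.cuda(), 1.0)  # runs mymuladd_cuda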

How to add torch.compile support for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To add torch.compile support for an operator, we must add a FakeTensor kernel (also
known as a “meta kernel” or “abstract impl”). FakeTensors are Tensors that have
metadata (i.e. shape, dtype, device) but no data: the FakeTensor kernel for an
operator specifies how to compute the metadata of output tensors given the metadata of input tensors.

We recommend that this be done from Python via the ``torch.library.register_fake`` API,
though it is possible to do this from C++ as well (see
`The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
for more details).

.. code-block:: python

   @torch.library.register_fake("extension_cpp::mymuladd")
   def _(a, b, c):
       torch._check(a.shape == b.shape)
       torch._check(a.dtype == torch.float)
       torch._check(b.dtype == torch.float)
       torch._check(a.device == b.device)
       return torch.empty_like(a)
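
With the fake kernel registered, torch.compile can trace through the custom operator. A minimal
sketch (again assuming the built extension is importable as ``extension_cpp``):

.. code-block:: python

   import torch
   import extension_cpp  # assumed package name from the setup.py above

   @torch.compile(fullgraph=True)
   def f(a, b):
       return torch.ops.extension_cpp.mymuladd(a, b, 1.0)

   print(f(torch.randn(3), torch.randn(3)))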

How to add training (autograd) support for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Use torch.library.register_autograd to add training support for an operator. Prefer
this over directly using Python torch.autograd.Function or C++ torch::autograd::Function;
one must use those in a very specific way to avoid silent incorrectness (see
`The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_
for more details).

.. code-block:: python

   def _backward(ctx, grad):
       a, b = ctx.saved_tensors
       grad_a, grad_b = None, None
       if ctx.needs_input_grad[0]:
           grad_a = grad * b
       if ctx.needs_input_grad[1]:
           grad_b = grad * a
       return grad_a, grad_b, None

   def _setup_context(ctx, inputs, output):
       a, b, c = inputs
       saved_a, saved_b = None, None
       if ctx.needs_input_grad[0]:
           saved_b = b
       if ctx.needs_input_grad[1]:
           saved_a = a
       ctx.save_for_backward(saved_a, saved_b)

   # This adds training support for the operator. You must provide us
   # the backward formula for the operator and a `setup_context` function
   # to save values to be used in the backward.
   torch.library.register_autograd(
       "extension_cpp::mymuladd", _backward, setup_context=_setup_context)

Note that the backward must be a composition of PyTorch-understood operators.
If you wish to use another custom C++ or CUDA kernel in your backwards pass,
it must be wrapped into a custom op.

So if we had our own custom mymul kernel, we would need to wrap it into a
custom operator and then call that from the backward:

.. code-block:: cpp

   // New! a mymul_cpu kernel
   at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
     TORCH_CHECK(a.sizes() == b.sizes());
     TORCH_CHECK(a.dtype() == at::kFloat);
     TORCH_CHECK(b.dtype() == at::kFloat);
     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* result_ptr = result.data_ptr<float>();
     for (int64_t i = 0; i < result.numel(); i++) {
       result_ptr[i] = a_ptr[i] * b_ptr[i];
     }
     return result;
   }

   TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
     // New! defining the mymul operator
     m.def("mymul(Tensor a, Tensor b) -> Tensor");
   }

   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
     m.impl("mymuladd", &mymuladd_cpu);
     // New! registering the cpu kernel for the mymul operator
     m.impl("mymul", &mymul_cpu);
   }

.. code-block:: python

   def _backward(ctx, grad):
       a, b = ctx.saved_tensors
       grad_a, grad_b = None, None
       if ctx.needs_input_grad[0]:
           grad_a = torch.ops.extension_cpp.mymul.default(grad, b)
       if ctx.needs_input_grad[1]:
           grad_b = torch.ops.extension_cpp.mymul.default(grad, a)
       return grad_a, grad_b, None


   def _setup_context(ctx, inputs, output):
       a, b, c = inputs
       saved_a, saved_b = None, None
       if ctx.needs_input_grad[0]:
           saved_b = b
       if ctx.needs_input_grad[1]:
           saved_a = a
       ctx.save_for_backward(saved_a, saved_b)


   # This adds training support for the operator. You must provide us
   # the backward formula for the operator and a `setup_context` function
   # to save values to be used in the backward.
   torch.library.register_autograd(
       "extension_cpp::mymuladd", _backward, setup_context=_setup_context)

How to test an operator
-----------------------

Use torch.library.opcheck to test that the custom op was registered correctly.
This does not test that the gradients are mathematically correct; please write
separate tests for that (either manual ones or torch.autograd.gradcheck).

.. code-block:: python

   def sample_inputs(device, *, requires_grad=False):
       def make_tensor(*size):
           return torch.randn(size, device=device, requires_grad=requires_grad)

       def make_nondiff_tensor(*size):
           return torch.randn(size, device=device, requires_grad=False)

       return [
           [make_tensor(3), make_tensor(3), 1],
           [make_tensor(20), make_tensor(20), 3.14],
           [make_tensor(20), make_nondiff_tensor(20), -123],
           [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
       ]

   def reference_muladd(a, b, c):
       return a * b + c

   device = "cpu"  # or "cuda", if the CUDA kernel is registered
   samples = sample_inputs(device, requires_grad=True)
   samples.extend(sample_inputs(device, requires_grad=False))
   for args in samples:
       # Correctness test
       result = torch.ops.extension_cpp.mymuladd(*args)
       expected = reference_muladd(*args)
       torch.testing.assert_close(result, expected)

       # Use opcheck to check for incorrect usage of operator registration APIs
       torch.library.opcheck(torch.ops.extension_cpp.mymuladd.default, args)
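
For the gradient correctness mentioned above, one option is to compare the custom op's gradients
against those of the pure-PyTorch reference implementation; torch.autograd.gradcheck is another
option if you relax its tolerances for float32 inputs. A sketch of the former, reusing
``reference_muladd`` from the snippet above:

.. code-block:: python

   x = torch.randn(8, requires_grad=True)
   y = torch.randn(8, requires_grad=True)
   grad_out = torch.randn(8)

   # Gradients of the custom op.
   out = torch.ops.extension_cpp.mymuladd(x, y, 0.5)
   gx, gy = torch.autograd.grad(out, (x, y), grad_out)

   # Gradients of the reference implementation.
   ref = reference_muladd(x, y, 0.5)
   ref_gx, ref_gy = torch.autograd.grad(ref, (x, y), grad_out)

   torch.testing.assert_close(gx, ref_gx)
   torch.testing.assert_close(gy, ref_gy)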

.. _mutable-ops:

How to create mutable operators
-------------------------------

You may wish to author a custom operator that mutates its inputs. Use ``Tensor(a!)``
to specify each mutable Tensor in the schema; otherwise, there will be undefined
behavior. If there are multiple mutated Tensors, use different names (i.e. ``Tensor(a!)``,
``Tensor(b!)``, ``Tensor(c!)``) for each mutable Tensor.

Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of ``a+b`` into ``out``.

.. code-block:: cpp

   // An example of an operator that mutates one of its inputs.
   void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
     TORCH_CHECK(a.sizes() == b.sizes());
     TORCH_CHECK(b.sizes() == out.sizes());
     TORCH_CHECK(a.dtype() == at::kFloat);
     TORCH_CHECK(b.dtype() == at::kFloat);
     TORCH_CHECK(out.dtype() == at::kFloat);
     TORCH_CHECK(out.is_contiguous());
     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
     TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* result_ptr = out.data_ptr<float>();
     for (int64_t i = 0; i < out.numel(); i++) {
       result_ptr[i] = a_ptr[i] + b_ptr[i];
     }
   }

When defining the operator, we must specify that it mutates the out Tensor in the schema:

.. code-block:: cpp

   TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
     m.def("mymul(Tensor a, Tensor b) -> Tensor");
     // New!
     m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
   }

   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
     m.impl("mymuladd", &mymuladd_cpu);
     m.impl("mymul", &mymul_cpu);
     // New!
     m.impl("myadd_out", &myadd_out_cpu);
   }
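
From Python, the call then looks like any other out= style operation. A small sketch (assumes the
extension has been rebuilt with the new operator):

.. code-block:: python

   a = torch.randn(3)
   b = torch.randn(3)
   out = torch.empty(3)
   torch.ops.extension_cpp.myadd_out(a, b, out)  # returns None; result is written into out
   print(torch.allclose(out, a + b))  # expected: True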

Please do not return any mutated Tensors as outputs of the operator; this will
run you into problems later down the line.

Conclusion
----------

In this tutorial, we went over the recommended approach to integrating custom C++
and CUDA operators with PyTorch. The TORCH_LIBRARY/torch.library APIs are fairly
low-level; more detail about how to use them can be found in
`The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_.