Commit 1182e37

committed: update
1 parent fffb6b0 commit 1182e37

File tree: 4 files changed, +82 −52 lines


advanced_source/cpp_custom_ops.rst

Lines changed: 54 additions & 41 deletions
@@ -1,8 +1,21 @@
+.. _cpp-custom-ops-tutorial:
+
 Custom C++ and CUDA Operators
 =============================
 
-.. note::
-   This tutorial is for PyTorch 2.4+ and the PyTorch nightlies.
+**Author:** `Richard Zou <https://github.com/zou3519>`_
+
+.. grid:: 2
+
+    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
+
+       * How to integrate custom operators written in C++/CUDA with PyTorch
+       * How to test custom operators using ``torch.library.opcheck``
+
+    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
+
+       * PyTorch 2.4 or later
+       * Basic understanding of C++ and CUDA programming
 
 PyTorch offers a large library of operators that work on Tensors (e.g. torch.add, torch.sum, etc).
 However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the
@@ -17,11 +30,11 @@ the operation are as follows:
     def mymuladd(a: Tensor, b: Tensor, c: float):
         return a * b + c
 
-You can find the end-to-end working example for this tutorial over at
-https://github.com/pytorch/extension-cpp .
+You can find the end-to-end working example for this tutorial
+`here <https://github.com/pytorch/extension-cpp>`_ .
 
-Build System
-------------
+Setting up the Build System
+---------------------------
 
 If you are developing custom C++/CUDA code, it must be compiled.
 Note that if you’re interfacing with a Python library that already has bindings
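For context, a minimal ``setup.py`` sketch of the build setup this section describes, assuming the C++ sources live under ``extension_cpp/csrc/`` (the path and file name are illustrative, not taken from this commit):

    # Build sketch using torch.utils.cpp_extension; the source path and
    # package name are assumptions for illustration only.
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CppExtension

    setup(
        name="extension_cpp",
        ext_modules=[
            CppExtension(
                "extension_cpp._C",
                sources=["extension_cpp/csrc/muladd.cpp"],
            )
        ],
        cmdclass={"build_ext": BuildExtension},
    )

For CUDA sources, ``CUDAExtension`` can be used in place of ``CppExtension``, with the ``.cu`` files listed alongside the ``.cpp`` ones.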
@@ -54,7 +67,7 @@ how this is set up.
 
 Defining the custom op and adding backend implementations
 ---------------------------------------------------------
-First, lets write a C++ function that computes ``mymuladd``:
+First, let's write a C++ function that computes ``mymuladd``:
 
 .. code-block:: cpp
 
@@ -86,8 +99,8 @@ Operator registration is a two step-process:
 - **Registering backend implementations** - In this step, implementations for various
   backends, such as CPU and CUDA, are associated with the operator.
 
-How to define an operator
-^^^^^^^^^^^^^^^^^^^^^^^^^
+Defining an operator
+^^^^^^^^^^^^^^^^^^^^
 To define an operator, follow these steps:
 
 1. select a namespace for an operator. We recommend the namespace be the name of your top-level
@@ -110,8 +123,8 @@ To define an operator, follow these steps:
 
 This makes the operator available from Python via ``torch.ops.extension_cpp.mymuladd``.
 
-How to register backend implementations for an operator
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Registering backend implementations for an operator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Use ``TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.
 
 .. code-block:: cpp
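As a rough Python-side analogue of the ``TORCH_LIBRARY`` / ``TORCH_LIBRARY_IMPL`` pair used in these hunks, a sketch for orientation only (the tutorial performs both steps in C++, and an operator must not be defined twice):

    import torch

    # Step 1: define the operator's schema.
    torch.library.define(
        "extension_cpp::mymuladd",
        "(Tensor a, Tensor b, float c) -> Tensor")

    # Step 2: register a backend (CPU) implementation for it.
    @torch.library.impl("extension_cpp::mymuladd", "cpu")
    def _(a, b, c):
        return a * b + c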
@@ -129,7 +142,7 @@ in a separate ``TORCH_LIBRARY_IMPL`` block:
       int idx = blockIdx.x * blockDim.x + threadIdx.x;
       if (idx < numel) result[idx] = a[idx] * b[idx] + c;
     }
-
+
     at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
       TORCH_CHECK(a.sizes() == b.sizes());
       TORCH_CHECK(a.dtype() == at::kFloat);
@@ -142,17 +155,17 @@ in a separate ``TORCH_LIBRARY_IMPL`` block:
       const float* a_ptr = a_contig.data_ptr<float>();
       const float* b_ptr = b_contig.data_ptr<float>();
       float* result_ptr = result.data_ptr<float>();
-
+
       int numel = a_contig.numel();
       muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
       return result;
     }
-
+
     TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
       m.impl("mymuladd", &mymuladd_cuda);
     }
 
-How to add torch.compile support for an operator
+Adding ``torch.compile`` support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 To add ``torch.compile`` support for an operator, we must add a FakeTensor kernel (also
@@ -182,8 +195,8 @@ for more details).
         torch._check(a.device == b.device)
         return torch.empty_like(a)
 
-How to set up hybrid Python/C++ registration
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Setting up hybrid Python/C++ registration
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 In this tutorial, we defined a custom operator in C++, added CPU/CUDA
 implementations in C++, and added ``FakeTensor`` kernels and backward formulas
 in Python. The order in which these registrations are loaded (or imported)
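A self-contained sketch of the FakeTensor kernel excerpted in the hunk above, assuming the operator has already been defined (the checks mirror the visible fragment):

    import torch

    # The fake kernel runs on metadata only: it validates the inputs and
    # returns an empty tensor with the right shape/dtype/device, touching
    # no real data.
    @torch.library.register_fake("extension_cpp::mymuladd")
    def _(a, b, c):
        torch._check(a.shape == b.shape)
        torch._check(a.dtype == torch.float)
        torch._check(b.dtype == torch.float)
        torch._check(a.device == b.device)
        return torch.empty_like(a)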
@@ -199,9 +212,9 @@ of two ways:
 2. If your C++ custom operator is located in a shared library object, you can
    also use ``torch.ops.load_library("/path/to/library.so")`` to load it.
 
-
-How to add training (autograd) support for an operator
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Adding training (autograd) support for an operator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Use ``torch.library.register_autograd`` to add training support for an operator. Prefer
 this over directly using Python ``torch.autograd.Function`` or C++ ``torch::autograd::Function``;
 you must use those in a very specific way to avoid silent incorrectness (see
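An illustrative load-order sketch for the hybrid setup described above, assuming the compiled C++ library has been packaged as an importable ``extension_cpp`` module:

    import torch
    import extension_cpp  # runs the C++ TORCH_LIBRARY registrations first

    # Alternatively, load a bare shared object directly:
    # torch.ops.load_library("/path/to/library.so")

    # Python-side registrations (FakeTensor kernel, autograd) can now refer
    # to the operator, which the C++ library has already defined.
    mymuladd = torch.ops.extension_cpp.mymuladd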
@@ -218,7 +231,7 @@ for more details).
         if ctx.needs_input_grad[1]:
             grad_b = grad * a
         return grad_a, grad_b, None
-
+
     def _setup_context(ctx, inputs, output):
         a, b, c = inputs
         saved_a, saved_b = None, None
@@ -227,7 +240,7 @@ for more details).
         if ctx.needs_input_grad[1]:
             saved_a = a
         ctx.save_for_backward(saved_a, saved_b)
-
+
     # This code adds training support for the operator. You must provide us
     # the backward formula for the operator and a `setup_context` function
     # to save values to be used in the backward.
@@ -248,8 +261,8 @@ custom operator and then call that from the backward:
       TORCH_CHECK(a.sizes() == b.sizes());
       TORCH_CHECK(a.dtype() == at::kFloat);
       TORCH_CHECK(b.dtype() == at::kFloat);
-      TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
-      TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
+      TORCH_CHECK(a.device().type() == at::DeviceType::CPU);
+      TORCH_CHECK(b.device().type() == at::DeviceType::CPU);
       at::Tensor a_contig = a.contiguous();
       at::Tensor b_contig = b.contiguous();
       at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
@@ -261,14 +274,14 @@ custom operator and then call that from the backward:
       }
       return result;
     }
-
+
     TORCH_LIBRARY(extension_cpp, m) {
       m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
       // New! defining the mymul operator
       m.def("mymul(Tensor a, Tensor b) -> Tensor");
     }
-
-
+
+
     TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
       m.impl("mymuladd", &mymuladd_cpu);
       // New! registering the cpu kernel for the mymul operator
@@ -285,8 +298,8 @@ custom operator and then call that from the backward:
         if ctx.needs_input_grad[1]:
             grad_b = torch.ops.extension_cpp.mymul.default(grad, a)
         return grad_a, grad_b, None
-
-
+
+
     def _setup_context(ctx, inputs, output):
         a, b, c = inputs
         saved_a, saved_b = None, None
@@ -295,16 +308,16 @@ custom operator and then call that from the backward:
         if ctx.needs_input_grad[1]:
             saved_a = a
         ctx.save_for_backward(saved_a, saved_b)
-
-
+
+
     # This code adds training support for the operator. You must provide us
    # the backward formula for the operator and a `setup_context` function
     # to save values to be used in the backward.
     torch.library.register_autograd(
         "extension_cpp::mymuladd", _backward, setup_context=_setup_context)
 
-How to test an operator
------------------------
+Testing an operator
+-------------------
 Use ``torch.library.opcheck`` to test that the custom op was registered correctly.
 Note that this function does not test that the gradients are mathematically correct
 -- plan to write separate tests for that, either manual ones or by using
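For the separate gradient tests this section recommends, a hedged ``torch.autograd.gradcheck`` sketch. ``gradcheck`` wants float64 inputs, while the tutorial's example kernels are float32-only, so this assumes a kernel built to also accept float64:

    import torch

    a = torch.randn(3, dtype=torch.double, requires_grad=True)
    b = torch.randn(3, dtype=torch.double, requires_grad=True)

    # Numerically compares the analytic gradients from the registered
    # backward against finite differences.
    torch.autograd.gradcheck(
        lambda a, b: torch.ops.extension_cpp.mymuladd(a, b, 1.0), (a, b))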
@@ -315,36 +328,36 @@ Note that this function does not test that the gradients are mathematically corr
     def sample_inputs(device, *, requires_grad=False):
         def make_tensor(*size):
             return torch.randn(size, device=device, requires_grad=requires_grad)
-
+
         def make_nondiff_tensor(*size):
             return torch.randn(size, device=device, requires_grad=False)
-
+
         return [
             [make_tensor(3), make_tensor(3), 1],
             [make_tensor(20), make_tensor(20), 3.14],
             [make_tensor(20), make_nondiff_tensor(20), -123],
             [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
         ]
-
+
     def reference_muladd(a, b, c):
         return a * b + c
-
+
     samples = sample_inputs(device, requires_grad=True)
     samples.extend(sample_inputs(device, requires_grad=False))
     for args in samples:
         # Correctness test
         result = torch.ops.extension_cpp.mymuladd(*args)
         expected = reference_muladd(*args)
         torch.testing.assert_close(result, expected)
-
+
         # Use opcheck to check for incorrect usage of operator registration APIs
         torch.library.opcheck(torch.ops.extension_cpp.mymuladd.default, args)
 
 .. _mutable-ops:
 
-How to create mutable operators
--------------------------------
-You may wish to author a custom operator that mutates its inputs. Use ``Tensor(a!)``
+Creating mutable operators
+--------------------------
+You may wish to author a custom operator that mutates its inputs. Use ``Tensor(a!)``
 to specify each mutable Tensor in the schema; otherwise, there will be undefined
 behavior. If there are multiple mutated Tensors, use different names (for example, ``Tensor(a!)``,
 ``Tensor(b!)``, ``Tensor(c!)``) for each mutable Tensor.
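A short sketch of the ``Tensor(a!)`` annotation in a schema, using a hypothetical in-place fill operator (the name and signature are illustrative only):

    import torch

    # ``(a!)`` marks ``x`` as mutated; mutable ops conventionally return ().
    torch.library.define(
        "extension_cpp::myfill_", "(Tensor(a!) x, float value) -> ()")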

advanced_source/custom_ops_landing_page.rst

Lines changed: 4 additions & 4 deletions
@@ -13,8 +13,8 @@ APIs.
 TL;DR
 -----
 
-How do I author a custom op from Python?
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Authoring a custom operator from Python
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Please see :ref:`python-custom-ops-tutorial`.
 
@@ -24,8 +24,8 @@ respect to ``torch.compile`` and ``torch.export``.
 - you have some Python bindings to C++/CUDA kernels and want those to compose with PyTorch
   subsystems (like ``torch.compile`` or ``torch.autograd``)
 
-How do I integrate custom C++ and/or CUDA code with PyTorch?
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Integrating custom C++ and/or CUDA code with PyTorch
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Please see :ref:`cpp-custom-ops-tutorial`.
 
advanced_source/python_custom_ops.py

Lines changed: 12 additions & 4 deletions
@@ -6,8 +6,16 @@
 Python Custom Operators
 =======================
 
-.. note::
-   This tutorial is for PyTorch 2.4+ and the PyTorch nightlies.
+.. grid:: 2
+
+    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
+
+       * How to integrate custom operators written in Python with PyTorch
+       * How to test custom operators using ``torch.library.opcheck``
+
+    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
+
+       * PyTorch 2.4 or later
 
 PyTorch offers a large library of operators that work on Tensors (e.g.
 ``torch.add``, ``torch.sum``, etc). However, you might wish to use a new customized
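For orientation, a minimal Python custom operator of the kind this tutorial builds, hedged as a sketch (a NumPy-backed sine in the spirit of the tutorial's running example; the exact code in the file may differ):

    import numpy as np
    import torch

    # Wraps a NumPy function as a PyTorch operator. Inputs arrive detached,
    # so calling .numpy() on a CPU tensor is safe inside the body.
    @torch.library.custom_op(
        "mylib::numpy_sin", mutates_args=(), device_types="cpu")
    def numpy_sin(x: torch.Tensor) -> torch.Tensor:
        x_np = x.numpy()
        y_np = np.sin(x_np)
        return torch.from_numpy(y_np)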
@@ -170,7 +178,7 @@ def setup_context(ctx, inputs, output):
 
 ######################################################################
 # Testing Python Custom operators
-# -------------------------
+# -------------------------------
 # Use ``torch.library.opcheck`` to test that the custom operator was registered
 # correctly. This does not test that the gradients are mathematically correct;
 # please write separate tests for that (either manual ones or ``torch.autograd.gradcheck``).
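A one-line usage sketch for ``opcheck``, assuming the hypothetical ``mylib::numpy_sin`` operator sketched earlier along with its FakeTensor kernel registration:

    import torch

    x = torch.randn(3)
    # Exercises the operator's registrations (schema, fake kernel, autograd)
    # on a concrete sample input.
    torch.library.opcheck(torch.ops.mylib.numpy_sin.default, (x,))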
@@ -192,7 +200,7 @@ def setup_context(ctx, inputs, output):
 
 ######################################################################
 # Mutable Python Custom operators
-# -------------------------
+# -------------------------------
 # You can also wrap a Python function that mutates its inputs into a custom
 # operator.
 # Functions that mutate inputs are common because that is how many low-level
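A hedged sketch of wrapping a mutating function, with a hypothetical in-place operator name; ``mutates_args`` plays the role the ``Tensor(a!)`` annotation plays in C++ schemas:

    import numpy as np
    import torch

    # Declaring x as mutated lets PyTorch handle the op soundly under
    # autograd and torch.compile; mutable ops conventionally return None.
    @torch.library.custom_op(
        "mylib::numpy_fill_", mutates_args={"x"}, device_types="cpu")
    def numpy_fill_(x: torch.Tensor, value: float) -> None:
        x_np = x.numpy()
        x_np.fill(value)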

index.rst

Lines changed: 12 additions & 3 deletions
@@ -416,6 +416,13 @@ Welcome to PyTorch Tutorials
    :link: advanced/cpp_frontend.html
    :tags: Frontend-APIs,C++
 
+.. customcarditem::
+   :header: Python Custom Operators Landing Page
+   :card_description: This is the landing page for all things related to custom operators in PyTorch.
+   :image: _static/img/thumbnails/cropped/Custom-Cpp-and-CUDA-Extensions.png
+   :link: advanced/custom_ops_landing_page.html
+   :tags: Extending-PyTorch,Frontend-APIs,C++,CUDA
+
 .. customcarditem::
    :header: Python Custom Operators
    :card_description: Create Custom Operators in Python. Useful for black-boxing a Python function for use with torch.compile.
@@ -588,7 +595,7 @@ Welcome to PyTorch Tutorials
 
 .. customcarditem::
    :header: (beta) Accelerating BERT with semi-structured sparsity
-   :card_description: Train BERT, prune it to be 2:4 sparse, and then accelerate it to achieve 2x inference speedups with semi-structured sparsity and torch.compile.
+   :card_description: Train BERT, prune it to be 2:4 sparse, and then accelerate it to achieve 2x inference speedups with semi-structured sparsity and torch.compile.
    :image: _static/img/thumbnails/cropped/Pruning-Tutorial.png
    :link: advanced/semi_structured_sparse.html
    :tags: Text,Model-Optimization
@@ -797,7 +804,7 @@ Welcome to PyTorch Tutorials
 
 .. customcarditem::
    :header: Using the ExecuTorch SDK to Profile a Model
-   :card_description: Explore how to use the ExecuTorch SDK to profile, debug, and visualize ExecuTorch models
+   :card_description: Explore how to use the ExecuTorch SDK to profile, debug, and visualize ExecuTorch models
    :image: _static/img/ExecuTorch-Logo-cropped.svg
    :link: https://pytorch.org/executorch/stable/tutorials/sdk-integration-tutorial.html
    :tags: Edge
@@ -935,6 +942,7 @@ Additional Resources
    beginner/basics/autogradqs_tutorial
    beginner/basics/optimization_tutorial
    beginner/basics/saveloadrun_tutorial
+   advanced/custom_ops_landing_page
 
 .. toctree::
    :maxdepth: 2
@@ -1081,6 +1089,7 @@ Additional Resources
    :hidden:
    :caption: Extending PyTorch
 
+   advanced/custom_ops_landing_page
    advanced/python_custom_ops
    advanced/cpp_custom_ops
    intermediate/custom_function_double_backward_tutorial
@@ -1153,7 +1162,7 @@ Additional Resources
    Using the ExecuTorch SDK to Profile a Model <https://pytorch.org/executorch/stable/tutorials/sdk-integration-tutorial.html>
    Building an ExecuTorch iOS Demo App <https://pytorch.org/executorch/stable/demo-apps-ios.html>
    Building an ExecuTorch Android Demo App <https://pytorch.org/executorch/stable/demo-apps-android.html>
-   Lowering a Model as a Delegate <https://pytorch.org/executorch/stable/examples-end-to-end-to-lower-model-to-delegate.html>
+   Lowering a Model as a Delegate <https://pytorch.org/executorch/stable/examples-end-to-end-to-lower-model-to-delegate.html>
 
 .. toctree::
    :maxdepth: 2