Commit 8faf499

Dispatcher tutorial
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
1 parent 32e5407 commit 8faf499

5 files changed, +411 -0

advanced_source/dispatcher.rst

Lines changed: 286 additions & 0 deletions

Dispatcher in C++
=================

The dispatcher is an internal component of PyTorch which is responsible for
figuring out what code should actually get run when you call a function like
``torch::add``. This can be nontrivial, because PyTorch operations need
to handle a lot of cross-cutting concerns that are "layered" on top of one
another. Here is a sampling of some of the things it handles:

* Switching between the CPU and CUDA implementations of an operator, depending
  on the devices of the input tensors.
* Switching between the autograd and backend implementations of an operator,
  depending on whether or not autograd handling is necessary.
* Applying autocasting when necessary for automatic mixed precision.
* Applying batching rules when an operator is run under a ``vmap`` call.
* Tracing execution of operations, if you are tracing a model for export.

If in your `custom operator code <torch_script_custom_ops>`_ you find yourself
manually writing if statements to handle these cases, the dispatcher APIs can
help organize your code. (Conversely, if your custom operator is very simple
and is only for CPU inference, you probably don't need to use the dispatcher;
just use the basic API.)

In this tutorial, we will describe how to structure a custom operator
registration to use the dispatcher to organize various components. We'll
assume that you are familiar with how to
`register an operator <torch_script_custom_ops>`_ and how to write
a `custom autograd function <cpp_autograd>`_.

Defining schema and backend implementations
-------------------------------------------

The general principle behind the dispatcher is that it divides the
implementation of an operator into multiple kernels, each of which
implements functionality for a specific *dispatch key*; for example,
`CPU`, `CUDA` or `Autograd`. The end effect is that when you call
an operator, we first execute the `Autograd` kernel, and then we
redispatch to the `CPU` or `CUDA` kernel depending on the device
types of the passed-in tensors.

Let's take a look at the various parts involved in making this
happen. First, we must define the schema for the operator in question.
Unlike simple pybind11-style operator registration, we don't actually
provide an implementation of our operator at this point; we just
provide a schema string specifying the type signature of the operator
that all of our other kernels will abide by:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
  :language: cpp
  :start-after: BEGIN TORCH_LIBRARY
  :end-before: END TORCH_LIBRARY
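
For orientation, a schema-only registration of this kind is quite short; a
minimal sketch (assuming the operator is named ``myops::myadd``, as in the
rest of this tutorial) might look like:

.. code-block:: cpp

  TORCH_LIBRARY(myops, m) {
    // The schema string only declares the signature; no kernel is
    // attached at this point.
    m.def("myadd(Tensor self, Tensor other) -> Tensor");
  }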

Next, we need to actually provide some implementations of this operator.
For concreteness, here is a really simple implementation of addition on CPU:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
  :language: cpp
  :start-after: BEGIN myadd_cpu
  :end-before: END myadd_cpu
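
If you don't have the accompanying ``op.cpp`` handy, a sketch of such a CPU
kernel, assuming ``float`` inputs of identical shape, could look like this:

.. code-block:: cpp

  Tensor myadd_cpu(const Tensor& self_, const Tensor& other_) {
    TORCH_CHECK(self_.sizes() == other_.sizes());
    TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CPU);
    TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CPU);
    // Make the inputs contiguous so we can walk them with raw pointers.
    Tensor self = self_.contiguous();
    Tensor other = other_.contiguous();
    Tensor result = torch::empty(self.sizes(), self.options());
    const float* self_ptr = self.data_ptr<float>();
    const float* other_ptr = other.data_ptr<float>();
    float* result_ptr = result.data_ptr<float>();
    for (int64_t i = 0; i < result.numel(); i++) {
      result_ptr[i] = self_ptr[i] + other_ptr[i];
    }
    return result;
  }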

We'd like to register this function as an implementation of ``myops::myadd``, but we
don't want to register it as a catch-all kernel to be run in all cases; we
only want it to be run when ``myops::myadd`` is called with CPU tensors.
To do this, we can use the ``TORCH_LIBRARY_IMPL`` macro:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
  :language: cpp
  :start-after: BEGIN TORCH_LIBRARY_IMPL CPU
  :end-before: END TORCH_LIBRARY_IMPL CPU
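
As a sketch, the registration block included above amounts to something like
the following (the CUDA block shown next is analogous, just with the ``CUDA``
key and a ``myadd_cuda`` kernel):

.. code-block:: cpp

  TORCH_LIBRARY_IMPL(myops, CPU, m) {
    // Associate the CPU dispatch key with our CPU kernel.
    m.impl("myadd", myadd_cpu);
  }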

The ``TORCH_LIBRARY_IMPL`` macro lets us register implementations for operators on
a specific dispatch key (in this case, ``CPU``). Each call to ``impl``
associates a CPU kernel with the corresponding operator (which we previously
defined in the ``TORCH_LIBRARY`` block). You can have as many
``TORCH_LIBRARY_IMPL`` blocks for a namespace as you like; so for example,
if we also have a CUDA implementation ``myadd_cuda``, we can register it
with:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
  :language: cpp
  :start-after: BEGIN TORCH_LIBRARY_IMPL CUDA
  :end-before: END TORCH_LIBRARY_IMPL CUDA

These registrations can be split across files or even across library boundaries; so
for example, you could have these two ``TORCH_LIBRARY_IMPL`` blocks compiled
into separate ``myops_cpu`` and ``myops_cuda`` dynamic libraries.

.. note::

  Did you know that you can also write ``TORCH_LIBRARY_IMPL`` blocks for existing
  core operators in PyTorch? This is how XLA support for PyTorch is
  implemented: the ``torch_xla`` library contains a ``TORCH_LIBRARY_IMPL``
  that provides implementations for all basic operators on the XLA dispatch
  key.
Adding autograd support
97+
-----------------------
98+
99+
At this point, we have an operator with both CPU and CUDA implementations. How
100+
can we add autograd support to it? As you might guess, we will register an
101+
autograd kernel (similar to what's described in the `custom autograd function <cpp_autograd>`_ tutorial)!
102+
However, there is a twist: unlike the CPU and CUDA kernels, the autograd kernel
103+
needs to *redispatch*: it needs to call back into the dispatcher to get to
104+
the final CPU and CUDA implementations.
105+

Thus, before we write the autograd kernel, let's write a *dispatching function*
which calls into the dispatcher to find the right kernel for your operator.
This function constitutes the public C++ API for your operators: in fact, all of
the tensor functions in PyTorch's C++ API call the dispatcher in the same
way under the hood. Here's what the dispatching function looks like:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
  :language: cpp
  :start-after: BEGIN myadd
  :end-before: END myadd
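
In outline, a dispatching function of this kind, sketched here under the same
``myops::myadd`` naming as above, looks roughly like:

.. code-block:: cpp

  Tensor myadd(const Tensor& self, const Tensor& other) {
    // Look up the operator handle once and cache it in a static variable.
    static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("myops::myadd", "")
      .typed<decltype(myadd)>();
    // Invoke the dispatcher; control ends up in whichever kernel is
    // appropriate for the inputs (Autograd, CPU, CUDA, ...).
    return op.call(self, other);
  }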

Let's break it down:

* In the first line, we look up a typed operator handle from the dispatcher
  corresponding to the operator that we are going to dispatch to.
  ``findSchemaOrThrow`` takes two arguments: the (namespace qualified) name
  of the operator, and the overload name of the operator (typically just
  the empty string). ``typed`` casts the dynamically typed handle into
  a statically typed handle (doing a runtime test to make sure you've given
  the correct C++ type), so that we can do a normal C++ call on it. We
  pass it ``decltype(myadd)`` since the type of the dispatching function is
  the same as the type of the underlying kernels registered to the dispatcher.

  For performance, this computation is done in a static variable, so that
  we only need to do the (slow) lookup once. If you typoed the name of the
  operator you want to call, this lookup will error the first time you call this
  function.

* In the second line, we simply ``call`` the operator handle with all of the
  arguments passed into the dispatching function. This will actually invoke
  the dispatcher and in the end control will be transferred to whatever kernel
  is appropriate for this call.

With the dispatching function in hand, we can now write the autograd kernel:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
  :language: cpp
  :start-after: BEGIN myadd_autograd
  :end-before: END myadd_autograd
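
As a rough sketch (again assuming the ``myadd`` dispatching function from
above, and the trivial gradient of pointwise addition), the autograd kernel
could look like:

.. code-block:: cpp

  class MyAddFunction
      : public torch::autograd::Function<MyAddFunction> {
  public:
    static Tensor forward(
        torch::autograd::AutogradContext* ctx, Tensor self, Tensor other) {
      // (1) Temporarily disable autograd handling so that the redispatch
      //     below reaches the backend (CPU/CUDA) kernels instead of
      //     coming back here.
      at::AutoNonVariableTypeMode g;
      // (2) Redispatch through the public dispatching function.
      return myadd(self, other);
    }

    static torch::autograd::tensor_list backward(
        torch::autograd::AutogradContext* ctx,
        torch::autograd::tensor_list grad_outputs) {
      auto grad_output = grad_outputs[0];
      // For pointwise addition, the gradient flows through unchanged
      // to both inputs.
      return {grad_output, grad_output};
    }
  };

  Tensor myadd_autograd(const Tensor& self, const Tensor& other) {
    return MyAddFunction::apply(self, other);
  }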

The autograd function is written as normal using ``torch::autograd::Function``,
except that instead of directly writing the implementation in ``forward()``,
we:

1. Turn off autograd handling with the ``at::AutoNonVariableTypeMode`` RAII
   guard, and then
2. Call the dispatching function ``myadd`` to call back into the dispatcher.

Without (1), your calls will loop infinitely (and overflow the stack), because
``myadd`` will send you back to the autograd implementation! With (1),
the redispatch will skip over autograd and go to the next handlers,
which will be either CPU or CUDA.

We can now register this function in the same way we registered the CPU/CUDA
functions:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
  :language: cpp
  :start-after: BEGIN TORCH_LIBRARY_IMPL Autograd
  :end-before: END TORCH_LIBRARY_IMPL Autograd
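
Again as a sketch, this registration looks like the earlier ones, just
targeting the ``Autograd`` dispatch key:

.. code-block:: cpp

  TORCH_LIBRARY_IMPL(myops, Autograd, m) {
    m.impl("myadd", myadd_autograd);
  }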

Going beyond autograd
---------------------

In some sense, the dispatcher isn't doing all that much: all it does is
implement a glorified if-statement, along the lines of this:

.. code-block:: cpp

  class MyAddFunction : ... {
  public:
    static Tensor forward(
        AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {

      if (self.device().type() == DeviceType::CPU) {
        return add_cpu(self, other);
      } else if (self.device().type() == DeviceType::CUDA) {
        return add_cuda(self, other);
      } else {
        TORCH_CHECK(0, "Unsupported device ", self.device().type());
      }
    }
    ...
  };

So why use the dispatcher? There are a few reasons:

1. It is decentralized. You can assemble all of the pieces of an operator
   (CPU, CUDA, Autograd) without having to write a single, centralized
   if statement that refers to all of them. Importantly, third parties can
   register extra implementations for other aspects without having to patch the
   original definition of an operator.

2. It supports more dispatch keys than CPU, CUDA and Autograd. You can
   see a full list of dispatch keys that are currently implemented
   in PyTorch in ``c10/core/DispatchKey.h``. These dispatch keys
   implement a variety of optional functionality for operators, and if you
   decide you want your custom operator to support this functionality,
   all you have to do is register a kernel for the appropriate key.

3. The dispatcher implements support for boxed fallback functions, which
   are functions that can be implemented once and apply to all operators
   in the system. Boxed fallbacks can be used to provide default behavior
   for a dispatch key; if you use the dispatcher to implement your operator,
   you also opt into the fallbacks for all of these operations. (A small
   sketch of registering a fallback appears after this list.)
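
The simplest fallback is a *fallthrough*, which just skips the key and
redispatches to whatever comes next. As a sketch (the underscore namespace
here means "all namespaces", and ``Autocast`` is used purely as an example
key):

.. code-block:: cpp

  TORCH_LIBRARY_IMPL(_, Autocast, m) {
    // Any operator without an explicit Autocast kernel skips this key
    // and falls through to the next one in line.
    m.fallback(torch::CppFunction::makeFallthrough());
  }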

Here are some particular dispatch keys which you may need to define an operator
for.

Autocast
^^^^^^^^

The Autocast dispatch key implements support for
`automatic mixed precision <https://developer.nvidia.com/automatic-mixed-precision>`_
(AMP). An autocast kernel typically modifies the behavior of an operator by casting its
input arguments to some precision before carrying out the operation. For some
operations, it is numerically safe to cast to lower precision, which is how AMP
can achieve speedups and reduced memory usage without sacrificing much
accuracy. A nontrivial autocast kernel looks something like this:

.. code-block:: cpp

  Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    return mymatmul(autocast::_cast(at::kHalf, self), autocast::_cast(at::kHalf, other));
  }

Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
dispatch before redispatching. By default, if no autocast kernel is provided,
we simply fall through directly to the regular operator implementation (no
autocasting occurs). (We didn't use ``myadd`` for this example, since pointwise
addition doesn't need autocasting and should just fall through.)

When should an autocast kernel be registered? Unfortunately, there aren't
cut-and-dried rules for when you should cast to a lower precision. You can
get a sense for which operators have autocasting behavior by looking at
the `AMP documentation
<https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_. Some other
general rules:

* Operations that do reductions should be carried out in float32,
* Any operation with multiple float tensor inputs has to standardize them
  to a common precision, and
* Any operation that does a convolution or gemm under the hood should
  probably be float16.

..

  NB: This doesn't work because torch.ops doesn't support names.

Named
^^^^^

`Named tensors <https://pytorch.org/docs/stable/named_tensor.html>`_ allow
users to associate explicit names with tensor dimensions, and then have those
dimensions be propagated when you run operations on those tensors. If you
define a new operator, you also have to define rules for how names should
be checked and propagated. The Named kernel handles implementing these rules.

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
  :language: cpp
  :start-after: BEGIN TORCH_LIBRARY_IMPL Named
  :end-before: END TORCH_LIBRARY_IMPL Named
269+
270+
Batched
271+
^^^^^^^
272+
273+
Batched tensors allow you to write your code in a per-example manner, and then
274+
have them be automatically batched when run under a ``vmap`` invocation. The
275+
API for writing batching rules is currently under development, but once it is
276+
stabilized, you can add support for ``vmap`` for your operators by registering
277+
a kernel at the Batched dispatch key.
278+
279+
Tracer
280+
^^^^^^
281+
282+
The Tracer dispatch key implements support for recording invocations of operators
283+
into a trace when you run ``torch.jit.trace``. We intend to provide a
284+
boxed fallback that will implement tracing for arbitrary operations,
285+
see `issue #41478 <https://github.com/pytorch/pytorch/issues/41478>` to track
286+
progress.
Lines changed: 8 additions & 0 deletions

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(dispatcher)

find_package(Torch REQUIRED)

add_library(dispatcher SHARED op.cpp)
target_compile_features(dispatcher PRIVATE cxx_std_14)
target_link_libraries(dispatcher "${TORCH_LIBRARIES}")
