
Commit b6ffdb9

Authored by Edward Z. Yang <ezyang@fb.com>
Dispatcher tutorial (#1072)
* Dispatcher tutorial
* typofix
* morefix

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
1 parent 0075e38 commit b6ffdb9

File tree

5 files changed, +406 -0 lines changed


advanced_source/dispatcher.rst

Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
Dispatcher in C++
=================

The dispatcher is an internal component of PyTorch which is responsible for
figuring out what code should actually get run when you call a function like
``torch::add``. This can be nontrivial, because PyTorch operations need
to handle a lot of cross-cutting concerns that are "layered" on top of one
another. Here is a sampling of some of the things it handles:

* Switching between the CPU and CUDA implementations of an operator, depending
  on the devices of the input tensors.
* Switching between the autograd and backend implementations of an operator,
  depending on whether or not autograd handling is necessary.
* Applying autocasting when necessary for automatic mixed precision.
* Applying batching rules when an operator is run under a ``vmap`` call.
* Tracing execution of operations, if you are tracing a model for export.

If in your `custom operator code <torch_script_custom_ops>`_ you find yourself
manually writing if statements to handle these cases, the dispatcher APIs can
help organize your code. (Conversely, if your custom operator is very simple
and is only for CPU inference, you probably don't need to use the dispatcher;
just use the basic API.)

In this tutorial, we will describe how to structure a custom operator
registration to use the dispatcher to organize various components. We'll
assume that you are familiar with how to
`register an operator <torch_script_custom_ops>`_ and how to write
a `custom autograd function <cpp_autograd>`_.

Defining schema and backend implementations
-------------------------------------------

The general principle behind the dispatcher is that it divides the
implementation of an operator into multiple kernels, each of which implements
functionality for a specific *dispatch key*; for example, CPU, CUDA or Autograd.
The dispatcher determines what the highest-priority dispatch key is at the time
you call an operator (this is done by looking at both the tensor arguments and
some thread-local state), and transfers control to the kernel for that
dispatch key. The end effect is that when you call an operator, we first
execute the Autograd kernel, and then we redispatch to the CPU or CUDA kernel
depending on the device types of the tensors passed in.

Let's take a look at the various parts involved in making this
happen. First, we must define the schema for the operator in question.
Unlike simple pybind11-style operator registration, we don't actually
provide an implementation of our operator at this point; we just
provide a schema string specifying the type signature of the operator
that all of our other kernels will abide by:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
   :language: cpp
   :start-after: BEGIN TORCH_LIBRARY
   :end-before: END TORCH_LIBRARY

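The schema string uses the same function-schema syntax as PyTorch's native
operators, so it can also express details like defaulted arguments and named
overloads. As a purely illustrative example (this schema is not used anywhere
else in this tutorial), the same block could additionally declare a scaled
variant of addition:

.. code-block:: cpp

   // Hypothetical extra schema: a Scalar argument with a default value.
   m.def("myadd_scaled(Tensor self, Tensor other, Scalar alpha=1) -> Tensor");
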
Next, we need to actually provide some implementations of this operator.
For concreteness, here is a really simple implementation of addition on CPU:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
   :language: cpp
   :start-after: BEGIN myadd_cpu
   :end-before: END myadd_cpu

We'd like to register this function as an implementation of ``myops::myadd``.
However, the simple way of registering it (``def("myadd", myadd_cpu)``) would
register the kernel to run in all cases, even if the tensor is not a CPU
tensor! (Internally, we refer to these as "catch-all" kernels, since they
catch all cases.) To ensure that ``myadd_cpu`` is only run for
CPU tensors, we can use the ``TORCH_LIBRARY_IMPL`` macro:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
   :language: cpp
   :start-after: BEGIN TORCH_LIBRARY_IMPL CPU
   :end-before: END TORCH_LIBRARY_IMPL CPU

The ``TORCH_LIBRARY_IMPL`` macro lets us register implementations for operators on
a specific dispatch key (in this case, CPU). Each call to ``impl``
associates a CPU kernel with the corresponding operator (which we previously
defined in the ``TORCH_LIBRARY`` block). If we also have a CUDA implementation ``myadd_cuda``,
we can register it in a separate ``TORCH_LIBRARY_IMPL`` block:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
   :language: cpp
   :start-after: BEGIN TORCH_LIBRARY_IMPL CUDA
   :end-before: END TORCH_LIBRARY_IMPL CUDA

These registrations can be split across files or even across library boundaries; so,
for example, you could have these two ``TORCH_LIBRARY_IMPL`` blocks compiled
into separate ``myops_cpu`` and ``myops_cuda`` dynamic libraries. Generally
speaking, the structure of your registrations will look like this:

1. A single ``TORCH_LIBRARY`` that lists every custom operator in your namespace
   in a centralized place.
2. A ``TORCH_LIBRARY_IMPL`` per dispatch key that registers implementations for
   that key (e.g., CPU or CUDA). If you like, you can further subdivide
   ``TORCH_LIBRARY_IMPL`` blocks into a block per operator. This is convenient
   if you have a separate file per operator implementation, but don't want to
   expose the operators in a header; you can just put the registration in the
   cpp file that defines your operator, as shown in the sketch after this list.

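As a purely illustrative sketch (the file and operator names here are
hypothetical and not part of this tutorial's code), such a per-operator file
might look like the following, assuming the central ``TORCH_LIBRARY`` block
also declares a matching ``myops::mymul`` schema:

.. code-block:: cpp

   // myops_mymul.cpp (hypothetical): kernel and registration in one file.
   #include <torch/torch.h>
   #include <torch/script.h>

   torch::Tensor mymul_cpu(const torch::Tensor& self, const torch::Tensor& other) {
     // Defer to existing CPU kernels just to keep the sketch short.
     return self * other;
   }

   // The registration lives next to the kernel; no header has to expose it.
   TORCH_LIBRARY_IMPL(myops, CPU, m) {
     m.impl("mymul", mymul_cpu);
   }
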
.. note::

   Did you know that you can also write ``TORCH_LIBRARY_IMPL`` blocks for existing
   core operators in PyTorch? This is how XLA support for PyTorch is
   implemented: the ``torch_xla`` library contains a ``TORCH_LIBRARY_IMPL``
   that provides implementations for all basic operators on the XLA dispatch
   key.

Adding autograd support
-----------------------

At this point, we have an operator with both CPU and CUDA implementations. How
can we add autograd support to it? As you might guess, we will register an
autograd kernel (similar to what's described in the `custom autograd function <cpp_autograd>`_ tutorial)!
However, there is a twist: unlike the CPU and CUDA kernels, the autograd kernel
needs to *redispatch*: it needs to call back into the dispatcher to get to
the final CPU and CUDA implementations.

Thus, before we write the autograd kernel, let's write a *dispatching function*
which calls into the dispatcher to find the right kernel for your operator.
This function constitutes the public C++ API for your operators; in fact, all of
the tensor functions in PyTorch's C++ API call the dispatcher in the same
way under the hood. Here's what the dispatching function looks like:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
   :language: cpp
   :start-after: BEGIN myadd
   :end-before: END myadd

Let's break it down:

* In the first line, we look up a typed operator handle from the dispatcher
  corresponding to the operator that we are going to dispatch to.
  ``findSchemaOrThrow`` takes two arguments: the (namespace-qualified) name
  of the operator, and the overload name of the operator (typically just
  the empty string). ``typed`` casts the dynamically typed handle into
  a statically typed handle (doing a runtime test to make sure you've given
  the correct C++ type), so that we can do a normal C++ call on it. We
  pass it ``decltype(myadd)`` since the type of the dispatching function is
  the same as the type of the underlying kernels registered to the dispatcher.

  For performance, this computation is done in a static variable, so that
  we only need to do the (slow) lookup once. If you mistyped the name of the
  operator you want to call, this lookup will error the first time you call this
  function.

* In the second line, we simply ``call`` the operator handle with all of the
  arguments passed into the dispatching function. This will actually invoke
  the dispatcher, and in the end control will be transferred to whatever kernel
  is appropriate for this call (see the short example after this list).

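As a quick illustration (a sketch, not part of the tutorial's files), once
these registrations are linked into your program, the dispatching function
behaves like any built-in tensor function:

.. code-block:: cpp

   // Calling the custom operator from ordinary C++ code.
   torch::Tensor a = torch::randn({8});
   torch::Tensor b = torch::randn({8});
   // The dispatcher routes this call to the appropriate kernel
   // (here, myadd_cpu, because both inputs live on the CPU).
   torch::Tensor c = myadd(a, b);
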
With the dispatch function in hand, we can now write the autograd kernel:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
   :language: cpp
   :start-after: BEGIN myadd_autograd
   :end-before: END myadd_autograd

The autograd function is written as normal using ``torch::autograd::Function``,
except that instead of directly writing the implementation in ``forward()``,
we:

1. Turn off autograd handling with the ``at::AutoNonVariableTypeMode`` RAII
   guard, and then
2. Call the dispatch function ``myadd`` to call back into the dispatcher.

Without (1), your calls will infinite loop (and stack overflow), because
``myadd`` will send you back to this function (as the highest priority dispatch
key would still be autograd). With (1),
autograd is excluded from the set of dispatch keys under consideration, and
we will go to the next handlers, which will be either CPU or CUDA.

We can now register this function in the same way we registered the CPU/CUDA
functions:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
   :language: cpp
   :start-after: BEGIN TORCH_LIBRARY_IMPL Autograd
   :end-before: END TORCH_LIBRARY_IMPL Autograd

Going beyond autograd
---------------------

In some sense, the dispatcher isn't doing all that much: all it does is
implement a glorified if-statement, along the lines of this:

.. code-block:: cpp

   class MyAddFunction : ... {
   public:
     static Tensor forward(
         AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {

       if (self.device().type() == DeviceType::CPU) {
         return add_cpu(self, other);
       } else if (self.device().type() == DeviceType::CUDA) {
         return add_cuda(self, other);
       } else {
         TORCH_CHECK(0, "Unsupported device ", self.device().type());
       }
     }
     ...
   }

So why use the dispatcher? There are a few reasons:

1. It is decentralized. You can assemble all of the pieces of an operator
   (CPU, CUDA, Autograd) without having to write a single, centralized
   if statement that refers to all of them. Importantly, third parties can
   register extra implementations for other aspects without having to patch the
   original definition of an operator.

2. It supports more dispatch keys than CPU, CUDA and Autograd. You can
   see a full list of dispatch keys that are currently implemented
   in PyTorch in ``c10/core/DispatchKey.h``. These dispatch keys
   implement a variety of optional functionality for operators, and if you
   decide you want your custom operator to support this functionality,
   all you have to do is register a kernel for the appropriate key.

3. The dispatcher implements support for boxed fallback functions, which
   are functions that can be implemented once and apply to all operators
   in the system. Boxed fallbacks can be used to provide default behavior
   for a dispatch key; if you use the dispatcher to implement your operator,
   you also opt into the fallbacks for all of these operations (a sketch of
   how such a fallback is registered follows this list).

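For instance, skipping a dispatch key by default is itself expressed as a
boxed *fallthrough* fallback registered for every operator at once via the
wildcard ``_`` namespace. Treat the sketch below as illustrative of the
pattern rather than a verbatim copy of PyTorch's own registrations:

.. code-block:: cpp

   // Sketch: one boxed fallback covering every operator on a dispatch key.
   // A fallthrough kernel just skips this key and continues dispatching.
   TORCH_LIBRARY_IMPL(_, Autocast, m) {
     m.fallback(torch::CppFunction::makeFallthrough());
   }
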
Here are some particular dispatch keys for which you may need to define an
operator.

Autocast
^^^^^^^^

The Autocast dispatch key implements support for
`automatic mixed precision <https://developer.nvidia.com/automatic-mixed-precision>`_
(AMP). An autocast kernel typically modifies the operation of an operator by casting the
input arguments to some precision before carrying out the operation. For some
operations, it is numerically safe to cast to lower precision, which is how AMP
can achieve speedups and reduced memory usage without sacrificing much
accuracy. A nontrivial autocast kernel looks something like this:

.. code-block:: cpp

   Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
     c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
     return mymatmul(autocast::_cast(at::kHalf, self), autocast::_cast(at::kHalf, other));
   }

Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
dispatch before redispatching. By default, if no autocast kernel is provided,
we simply fall through directly to the regular operator implementation (no
autocasting occurs). (We didn't use ``myadd`` for this example, since pointwise
addition doesn't do autocasting and should just fall through.)

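Like any other kernel, an autocast kernel is attached to its operator with
``TORCH_LIBRARY_IMPL``. A sketch for the hypothetical ``mymatmul`` operator
above (assuming it has been declared in a ``TORCH_LIBRARY(myops, ...)`` block)
might look like:

.. code-block:: cpp

   // Register the autocast wrapper on the Autocast dispatch key.
   TORCH_LIBRARY_IMPL(myops, Autocast, m) {
     m.impl("mymatmul", mymatmul_autocast);
   }
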
When should an autocast kernel be registered? Unfortunately, there aren't
cut-and-dried rules for when you should cast to a lower precision. You can
get a sense for which operators have autocasting behavior by looking at
the `AMP documentation
<https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_. Some other
general rules:

* Operations that do reductions should be carried out in float32,
* Any operation with multiple float tensor inputs has to standardize them
  to a common precision, and
* Any operation that does a convolution or gemm under the hood should
  probably be float16.

Batched
^^^^^^^

Batched tensors allow you to write your code in a per-example manner, and then
have it be automatically batched when run under a ``vmap`` invocation. The
API for writing batching rules is currently under development, but once it is
stabilized, you can add support for ``vmap`` for your operators by registering
a kernel at the Batched dispatch key.

Tracer
^^^^^^

The Tracer dispatch key implements support for recording invocations of operators
into a trace when you run ``torch.jit.trace``. We intend to provide a
boxed fallback that will implement tracing for arbitrary operations;
see `issue #41478 <https://github.com/pytorch/pytorch/issues/41478>`_ to track
progress.

advanced_source/dispatcher/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(dispatcher)

find_package(Torch REQUIRED)

add_library(dispatcher SHARED op.cpp)
target_compile_features(dispatcher PRIVATE cxx_std_14)
target_link_libraries(dispatcher "${TORCH_LIBRARIES}")

advanced_source/dispatcher/op.cpp

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
#include <torch/torch.h>
#include <torch/script.h>

#include <ATen/NamedTensorUtils.h>

using torch::Tensor;
using torch::DeviceType;
using torch::autograd::tensor_list;
using torch::autograd::AutogradContext;

// BEGIN myadd
Tensor myadd(const Tensor& self, const Tensor& other) {
  // Look up the operator handle once (hence static) and dispatch through it.
  static auto op = torch::Dispatcher::singleton()
    .findSchemaOrThrow("myops::myadd", "")
    .typed<decltype(myadd)>();
  return op.call(self, other);
}
// END myadd

// BEGIN TORCH_LIBRARY
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}
// END TORCH_LIBRARY

// BEGIN myadd_cpu
Tensor myadd_cpu(const Tensor& self_, const Tensor& other_) {
  TORCH_CHECK(self_.sizes() == other_.sizes());
  TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CPU);
  Tensor self = self_.contiguous();
  Tensor other = other_.contiguous();
  Tensor result = torch::empty(self.sizes(), self.options());
  const float* self_ptr = self.data_ptr<float>();
  const float* other_ptr = other.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = self_ptr[i] + other_ptr[i];
  }
  return result;
}
// END myadd_cpu

// BEGIN TORCH_LIBRARY_IMPL CPU
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);
}
// END TORCH_LIBRARY_IMPL CPU

Tensor myadd_cuda(const Tensor& self, const Tensor& other) {
  // Insert your CUDA implementation here
  TORCH_CHECK(0, "CUDA not yet implemented");
}

// BEGIN TORCH_LIBRARY_IMPL CUDA
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("myadd", myadd_cuda);
}
// END TORCH_LIBRARY_IMPL CUDA

// BEGIN myadd_autograd
class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
public:
  static Tensor forward(
      AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
    // Exclude autograd from dispatch, then call back into the dispatcher
    // so we land on the CPU/CUDA kernel instead of looping back here.
    at::AutoNonVariableTypeMode g;
    return myadd(self, other);
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    // Addition passes gradients through unchanged to both inputs.
    auto grad_output = grad_outputs[0];
    return {grad_output, grad_output};
  }
};

Tensor myadd_autograd(const Tensor& self, const Tensor& other) {
  return MyAddFunction::apply(self, other)[0];
}
// END myadd_autograd

// BEGIN TORCH_LIBRARY_IMPL Autograd
TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  m.impl("myadd", myadd_autograd);
}
// END TORCH_LIBRARY_IMPL Autograd

#if 0
// BEGIN TORCH_LIBRARY_IMPL Named
Tensor myadd_named(const Tensor& self, const Tensor& other) {
  // TODO: shouldn't need to do size check here
  TORCH_CHECK(self.sizes() == other.sizes());
  auto maybe_outnames = at::unify_from_right(self.names(), other.names());
  auto result = ([&]() {
    at::NoNamesGuard guard;
    return myadd(self, other);
  })();
  at::namedinference::propagate_names_if_nonempty(result, maybe_outnames);
  return result;
}

TORCH_LIBRARY_IMPL(myops, Named, m) {
  m.impl("myadd", myadd_named);
}
// END TORCH_LIBRARY_IMPL Named
#endif
