@@ -31,12 +31,14 @@ Defining schema and backend implementations
-------------------------------------------

The general principle behind the dispatcher is that it divides the
- implementation of an operator into multiple kernels, each of which
- implements functionality for a specific *dispatch key*; for example,
- CPU, CUDA or Autograd. The end effect is that when you call
- an operator, we first execute the Autograd kernel, and then we
- redispatch to the CPU or CUDA kernel depending on the device
- types of the passed in tensors.
+ implementation of an operator into multiple kernels, each of which implements
+ functionality for a specific *dispatch key*; for example, CPU, CUDA or Autograd.
+ The dispatcher determines what the highest priority dispatch key is at the time
+ you call an operator (this is done by looking at both the tensor arguments as
+ well as some thread local state), and transfers control to the kernel for that
+ dispatch key. The end effect is that when you call an operator, we first
+ execute the Autograd kernel, and then we redispatch to the CPU or CUDA kernel
+ depending on the device types of the passed in tensors.

Let's take a look at the various parts involved in making this
happen. First, we must define the schema for the operator in question.
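+
+ For orientation, a ``TORCH_LIBRARY`` block defining such a schema might look
+ roughly like this (a minimal sketch; the exact signature shown here is
+ illustrative):
+
+ .. code-block:: cpp
+
+   // A centralized TORCH_LIBRARY block for the myops namespace.
+   // The schema string tells the dispatcher the operator's signature;
+   // no implementation is attached at this point.
+   TORCH_LIBRARY(myops, m) {
+     m.def("myadd(Tensor self, Tensor other) -> Tensor");
+   }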
@@ -58,10 +60,12 @@ For concreteness, here is a really simple implementation of addition on CPU:
  :start-after: BEGIN myadd_cpu
  :end-before: END myadd_cpu

- We'd like to register this function as an implementation of ``myops::myadd``, but we
- don't want to register it as a catch-all kernel to be run in all cases; we
- only want it to be run when we call ``myops::myadd`` at the backend on CPU tensors.
- To do this, we can use the ``TORCH_LIBRARY_IMPL`` macro:
+ We'd like to register this function as an implementation of ``myops::myadd``.
+ However, the simple way of registering it (``def("myadd", myadd_cpu)``) would
+ register the kernel to run in all cases, even if the tensor is not a CPU
+ tensor! (Internally, we refer to these as "catch-all" kernels, since they
+ catch all cases.) To ensure that ``myadd_cpu`` is only run for
+ CPU tensors, we can use the ``TORCH_LIBRARY_IMPL`` macro:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
  :language: cpp
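+
+ In essence, the snippet included above associates ``myadd_cpu`` with the CPU
+ dispatch key; a minimal sketch of such a block (the real code lives in
+ ``op.cpp``) would be:
+
+ .. code-block:: cpp
+
+   // Register myadd_cpu as the implementation of myops::myadd for the
+   // CPU dispatch key only; it no longer acts as a catch-all.
+   TORCH_LIBRARY_IMPL(myops, CPU, m) {
+     m.impl("myadd", myadd_cpu);
+   }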
@@ -71,10 +75,8 @@ To do this, we can use the ``TORCH_LIBRARY_IMPL`` macro:
The ``TORCH_LIBRARY_IMPL`` macro lets us register implementations for operators on
a specific dispatch key (in this case, CPU). Each call to ``impl``
associates a CPU kernel with the corresponding operator (which we previously
- defined in the ``TORCH_LIBRARY`` block). You can have as many
- ``TORCH_LIBRARY_IMPL`` blocks for a namespace as you like; so for example,
- if we also have a CUDA implementation ``myadd_cuda``, we can register it
- with:
+ defined in the ``TORCH_LIBRARY`` block). If we also have a CUDA implementation ``myadd_cuda``,
+ we can register it in a separate ``TORCH_LIBRARY_IMPL`` block:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
  :language: cpp
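+
+ Again, only the dispatch key changes; a sketch of the included block
+ (assuming a ``myadd_cuda`` kernel with the same signature) would be:
+
+ .. code-block:: cpp
+
+   // Register myadd_cuda for the CUDA dispatch key; it is selected when
+   // dispatch resolves to CUDA, e.g. for tensors on a CUDA device.
+   TORCH_LIBRARY_IMPL(myops, CUDA, m) {
+     m.impl("myadd", myadd_cuda);
+   }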
@@ -83,7 +85,17 @@

These registrations can be split across files or even across library boundaries; so
for example, you could have these two ``TORCH_LIBRARY_IMPL`` blocks compiled
- into a separate ``myops_cpu`` and ``myops_cuda`` dynamic library.
+ into separate ``myops_cpu`` and ``myops_cuda`` dynamic libraries. Generally
+ speaking, the structure of your registrations will look like this (a sketch
+ follows the list):
+
+ 1. A single ``TORCH_LIBRARY`` that lists every custom operator in your namespace
+    in a centralized place.
+ 2. A ``TORCH_LIBRARY_IMPL`` per dispatch key that registers implementations for
+    that key (e.g., CPU or CUDA). If you like, you can further subdivide
+    ``TORCH_LIBRARY_IMPL`` blocks into a block per operator. This is convenient
+    if you have a separate file per operator implementation, but don't want to
+    expose the operators in a header; you can just put the registration in the
+    cpp file that defines your operator.
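+
+ As an illustrative sketch of that layout (the file names are hypothetical;
+ ``myadd_cpu`` and ``myadd_cuda`` are the kernels from above):
+
+ .. code-block:: cpp
+
+   // myops.cpp (hypothetical): one centralized TORCH_LIBRARY block
+   // declaring every operator in the namespace.
+   TORCH_LIBRARY(myops, m) {
+     m.def("myadd(Tensor self, Tensor other) -> Tensor");
+   }
+
+   // myadd_cpu.cpp (hypothetical): registration lives next to the kernel,
+   // so myadd_cpu never needs to be exposed in a header.
+   TORCH_LIBRARY_IMPL(myops, CPU, m) {
+     m.impl("myadd", myadd_cpu);
+   }
+
+   // myadd_cuda.cpp (hypothetical): same pattern for the CUDA key.
+   TORCH_LIBRARY_IMPL(myops, CUDA, m) {
+     m.impl("myadd", myadd_cuda);
+   }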

.. note::
@@ -152,9 +164,10 @@
2. Call the dispatch function ``myadd`` to call back into the dispatcher.

Without (1), your calls will infinite loop (and stack overflow), because
- ``myadd`` will send you back to the autograd implementation! With (1),
- the redispatch will skip over autograd and go to the next handlers,
- which will either be CPU and CUDA.
+ ``myadd`` will send you back to this function (as the highest priority dispatch
+ key would still be autograd). With (1),
+ autograd is excluded from the set of dispatch keys under consideration, and
+ we will go to the next handlers, which will be either CPU or CUDA.
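+
+ As a sketch, the body of such an autograd kernel combines (1) and (2) like
+ this (``myadd_autograd`` is an assumed name; ``at::AutoNonVariableTypeMode``
+ is the RAII guard PyTorch provided for step (1) at the time of writing):
+
+ .. code-block:: cpp
+
+   torch::Tensor myadd_autograd(torch::Tensor self, torch::Tensor other) {
+     // ... set up whatever autograd metadata the backward pass needs ...
+     at::AutoNonVariableTypeMode guard;  // (1) exclude autograd from dispatch
+     return myadd(self, other);          // (2) re-enter the dispatcher
+   }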

We can now register this function in the same way we registered the CPU/CUDA
functions: