
Commit c54666c

draft
1 parent fa7cff7 commit c54666c

1 file changed: advanced_source/dispatcher.rst (+57, -24 lines)

Autocast
^^^^^^^^

The Autocast dispatch key implements support for
`automatic mixed precision (AMP) <https://pytorch.org/docs/stable/amp.html>`_.
An autocast wrapper kernel typically casts incoming ``float16`` or ``float32`` CUDA tensors
to some preferred precision before running the op.
For example, matmuls and convolutions on floating-point CUDA tensors usually run faster
and use less memory in ``float16`` without impairing convergence.
Autocast wrappers only have an effect in
`autocast-enabled contexts <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_.

Here's an autocast wrapper for a hypothetical custom matmul, along with its registration:

.. code-block:: cpp

  // Autocast-specific helper functions
  #include <ATen/autocast_mode.h>

  Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    return mymatmul(at::autocast::cached_cast(at::kHalf, self),
                    at::autocast::cached_cast(at::kHalf, other));
  }

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("mymatmul", mymatmul_autocast);
  }

``cached_cast(kHalf, tensor)`` casts ``tensor`` to ``float16`` if ``tensor`` is CUDA and ``float32``;
otherwise, it leaves ``tensor`` unchanged (cf. the
`eligibility policy <https://pytorch.org/docs/stable/amp.html#op-eligibility>`_ for natively autocasted ops).
This ensures that if the network calls ``mymatmul`` on any mixture of ``float16`` and ``float32`` CUDA tensors,
``mymatmul`` runs in ``float16``. Meanwhile, calls to ``mymatmul`` with non-CUDA, integer-type, or ``float64``
inputs are unaffected. Using ``cached_cast`` to follow the native eligibility policy in your own autocast wrapper
is recommended, but not required. For example, if you wanted to force ``float16`` execution for all input types,
you could ``return mymatmul(self.half(), other.half());`` instead of using ``cached_cast``, as sketched below.
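
That alternative, written out as a full wrapper (the function name here is hypothetical and only
illustrates the sentence above), might look like:

.. code-block:: cpp

  #include <ATen/autocast_mode.h>

  // Ignores the eligibility policy: every input is cast to float16,
  // regardless of its original dtype or device.
  Tensor mymatmul_autocast_forced_half(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    return mymatmul(self.half(), other.half());
  }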

Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
dispatch before redispatching. By default, if no autocast wrapper is provided,
we fall through directly to the regular operator implementation (no
autocasting occurs). (We didn't use ``myadd`` for this example, since pointwise
addition doesn't need autocasting and should just fall through.)
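
If you want to state that behavior explicitly, you can also register a fallthrough for the
``Autocast`` key yourself. This is optional (the default already falls through); the sketch below
uses ``torch::CppFunction::makeFallthrough()`` purely as an illustration:

.. code-block:: cpp

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    // Skip autocasting for myadd and go straight to the next dispatch key.
    m.impl("myadd", torch::CppFunction::makeFallthrough());
  }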

When should an autocast wrapper be registered? Unfortunately, there aren't
cut-and-dried rules for an op's preferred precision. You can
get a sense for some native ops' preferred precisions by looking at the
`cast lists <https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_.
General guidance:

* Ops that do reductions should probably execute in ``float32``
  (a sketch of this case follows the ``promote_type`` example below),
* Any op that does a convolution or gemm under the hood should
  probably execute in ``float16``, and
* Other ops with multiple floating-point tensor inputs should standardize
  them to a common precision (unless the implementation is known to support
  inputs with different precisions).

If your custom op falls into the third category, the ``promote_type`` template
helps figure out the widest floating-point type present among input tensors, which is
usually the safest option for the execution type:

.. code-block:: cpp

  #include <ATen/autocast_mode.h>

  Tensor my_multiple_input_op_autocast(const Tensor& t0, const Tensor& t1) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    auto exec_type = at::autocast::promote_type(at::kHalf, t0, t1);
    return my_multiple_input_op(at::autocast::cached_cast(exec_type, t0),
                                at::autocast::cached_cast(exec_type, t1));
  }
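
For the first category (ops that should run in ``float32``), a similar wrapper can cast eligible
inputs up instead of down. Here's a minimal sketch under the same pattern, assuming a hypothetical
``mysum`` reduction op in the ``myops`` namespace:

.. code-block:: cpp

  #include <ATen/autocast_mode.h>

  Tensor mysum_autocast(const Tensor& self) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    // cached_cast(kFloat, ...) upcasts eligible float16 CUDA tensors to float32
    // and leaves other tensors unchanged.
    return mysum(at::autocast::cached_cast(at::kFloat, self));
  }

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("mysum", mysum_autocast);
  }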

Batched
^^^^^^^
