
Commit abf8b6b ("fixes"), 1 parent c54666c

1 file changed, +12 -9 lines: advanced_source/dispatcher.rst
@@ -229,17 +229,18 @@ Autocast
 ^^^^^^^^

 The Autocast dispatch key implements support for
-`automatic mixed precision (AMP)<https://pytorch.org/docs/stable/amp.html>`_.
+`automatic mixed precision (AMP) <https://pytorch.org/docs/stable/amp.html>`_.
 An autocast wrapper kernel typically casts incoming ``float16`` or ``float32`` CUDA tensors
 to some preferred precision before running the op.
 For example, matmuls and convolutions on floating-point CUDA tensors usually run faster
 and use less memory in ``float16`` without impairing convergence.
 Autocast wrappers only have an effect in
-`autocast-enabled contexts<https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_.
+`autocast-enabled contexts <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_.

 Here's an autocast wrapper for a hypothetical custom matmul, along with its registration:

 .. code-block:: cpp
+
   // Autocast-specific helper functions
   #include <ATen/autocast_mode.h>

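For context (not part of the commit): the hunk above cuts off right after the ``#include``, so the wrapper body and its registration are not visible here. A minimal sketch of what such a wrapper could look like follows; the ``myops`` namespace and the exact body are illustrative assumptions, with only ``mymatmul``, ``cached_cast``, and the ``Autocast`` key taken from the surrounding text.

.. code-block:: cpp

  // Sketch only: casts eligible inputs to float16, then redispatches to the
  // regular mymatmul kernel. "myops" is a hypothetical namespace.
  Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
    // Exclude Autocast so the redispatch below doesn't re-enter this wrapper.
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    // cached_cast only casts eligible floating-point tensors and reuses
    // cached casts where possible; other tensors pass through unchanged.
    return mymatmul(at::autocast::cached_cast(at::kHalf, self),
                    at::autocast::cached_cast(at::kHalf, other));
  }

  // Register the wrapper under the Autocast key for the op.
  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("mymatmul", mymatmul_autocast);
  }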
@@ -263,7 +264,9 @@ is recommended, but not required. For example, if you wanted to force ``float16
 you could ``return mymatmul(self.half(), other.half());`` instead of using ``cached_cast``.

 Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
-dispatch before redispatching. By default, if no autocast wrapper is provided,
+dispatch before redispatching.
+
+By default, if no autocast wrapper is provided,
 we fallthrough directly to the regular operator implementation (no
 autocasting occurs). (We didn't use ``myadd`` for this example, since pointwise
 addition doesn't need autocasting and should just fall through.)
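For context (not part of the commit): the default fallthrough described above is registered by PyTorch core for the ``Autocast`` key, so a custom op doesn't need to do anything to get it. A rough sketch of how such a catch-all fallthrough is declared, shown only for illustration:

.. code-block:: cpp

  #include <torch/library.h>

  // Illustrative sketch: a fallthrough kernel registered for every operator
  // under the Autocast key, so ops without an explicit autocast wrapper
  // redispatch straight to their regular kernels. PyTorch already does this
  // for you; a custom extension would not normally write it.
  TORCH_LIBRARY_IMPL(_, Autocast, m) {
    m.fallback(torch::CppFunction::makeFallthrough());
  }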
@@ -274,23 +277,23 @@ get a sense for some native ops' preferred precisions by looking at the
 `cast lists <https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_.
 General guidance:

-* Ops that do reductions should probably execute in float32,
+* Ops that do reductions should probably execute in ``float32``,
 * Any op that does a convolution or gemm under the hood should
-  probably execute in float16, and
+  probably execute in ``float16``, and
 * Other ops with multiple floating-point tensor inputs should standardize
-  them to a common precision (unless the implementation is known to support
-  inputs with different precisions).
+  them to a common precision (unless the implementation supports inputs with different precisions).

 If your custom op falls into the third category, the ``promote_type`` template
 helps figure out the widest floating-point type present among input tensors, which is
-usually the safest option for the execution type:
+the safest choice for the execution type:

 .. code-block:: cpp
+
   #include <ATen/autocast_mode.h>

   Tensor my_multiple_input_op_autocast(const Tensor& t0, const Tensor& t1) {
     c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
-    auto exec_type = at::autocast::promote_type(at::kHalf, t0, t1);
+    auto exec_type = at::autocast::promote_type(at::kHalf/*optimistic initial guess*/, t0, t1);
     return my_multiple_input_op(at::autocast::cached_cast(exec_type, t0),
                                 at::autocast::cached_cast(exec_type, t1));
   }
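For completeness (not part of the commit): the hunk ends before any registration, but a wrapper like ``my_multiple_input_op_autocast`` would still be registered under the ``Autocast`` key in the same way as the matmul example. A sketch, with the ``myops`` namespace again assumed:

.. code-block:: cpp

  // Sketch only: register the autocast wrapper for the (hypothetical)
  // my_multiple_input_op in the assumed "myops" namespace.
  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("my_multiple_input_op", my_multiple_input_op_autocast);
  }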
