@@ -229,17 +229,18 @@ Autocast
^^^^^^^^

The Autocast dispatch key implements support for
- `automatic mixed precision (AMP)<https://pytorch.org/docs/stable/amp.html> `_.
+ `automatic mixed precision (AMP) <https://pytorch.org/docs/stable/amp.html>`_.
An autocast wrapper kernel typically casts incoming ``float16`` or ``float32`` CUDA tensors
to some preferred precision before running the op.
For example, matmuls and convolutions on floating-point CUDA tensors usually run faster
and use less memory in ``float16`` without impairing convergence.
Autocast wrappers only have an effect in
- `autocast-enabled contexts<https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast> `_.
+ `autocast-enabled contexts <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_.

Here's an autocast wrapper for a hypothetical custom matmul, along with its registration:

.. code-block:: cpp
+
  // Autocast-specific helper functions
  #include <ATen/autocast_mode.h>

@@ -263,7 +264,9 @@ is recommended, but not required. For example, if you wanted to force ``float16
you could ``return mymatmul(self.half(), other.half());`` instead of using ``cached_cast``.

Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
- dispatch before redispatching. By default, if no autocast wrapper is provided,
+ dispatch before redispatching.
+
+ By default, if no autocast wrapper is provided,
we fallthrough directly to the regular operator implementation (no
autocasting occurs). (We didn't use ``myadd`` for this example, since pointwise
addition doesn't need autocasting and should just fall through.)
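
If you did want to make that fall-through explicit for a particular op under the
``Autocast`` key (rather than relying on the default behavior), a minimal sketch,
assuming a hypothetical ``myops::myadd`` defined elsewhere, could look like this:

.. code-block:: cpp

  #include <torch/library.h>

  // Sketch only: explicitly register a fallthrough for "myadd" under the Autocast
  // key, so autocast-enabled dispatch goes straight to the op's regular kernel.
  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("myadd", torch::CppFunction::makeFallthrough());
  }
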
@@ -274,23 +277,23 @@ get a sense for some native ops' preferred precisions by looking at the
`cast lists <https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_.
General guidance:

- * Ops that do reductions should probably execute in float32,
+ * Ops that do reductions should probably execute in ``float32``,
* Any op that does a convolution or gemm under the hood should
-   probably execute in float16, and
+   probably execute in ``float16``, and
* Other ops with multiple floating-point tensor inputs should standardize
-   them to a common precision (unless the implementation is known to support
-   inputs with different precisions).
+   them to a common precision (unless the implementation supports inputs with different precisions).

If your custom op falls into the third category, the ``promote_type`` template
helps figure out the widest floating-point type present among input tensors, which is
- usually the safest option for the execution type:
+ the safest choice for the execution type:

.. code-block:: cpp
+
  #include <ATen/autocast_mode.h>

  Tensor my_multiple_input_op_autocast(const Tensor& t0, const Tensor& t1) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
-     auto exec_type = at::autocast::promote_type(at::kHalf, t0, t1);
+     auto exec_type = at::autocast::promote_type(at::kHalf /*optimistic initial guess*/, t0, t1);
    return my_multiple_input_op(at::autocast::cached_cast(exec_type, t0),
                                at::autocast::cached_cast(exec_type, t1));
  }
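
As with the matmul example earlier in this section, the multiple-input wrapper above
still needs to be registered to the ``Autocast`` key. A minimal sketch, assuming the
op was defined in a hypothetical ``myops`` namespace:

.. code-block:: cpp

  #include <torch/library.h>

  // Sketch only: route Autocast-key dispatch for "my_multiple_input_op"
  // to the wrapper defined above.
  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("my_multiple_input_op", my_multiple_input_op_autocast);
  }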