@@ -229,38 +229,71 @@ Autocast
^^^^^^^^

The Autocast dispatch key implements support for
- `automatic mixed precision <https://developer.nvidia.com/automatic-mixed-precision>`_
- (AMP). An autocast kernel typically modifies the operation of an operator by casting the
- input arguments to some precision before carrying out the operation. For some
- operations, it is numerically safe to cast to lower precision, which is how AMP
- can achieve speed ups and reduced memory usage without sacrificing much
- accuracy. A nontrivial autocast kernel looks something like this:
+ `automatic mixed precision (AMP) <https://pytorch.org/docs/stable/amp.html>`_.
+ An autocast wrapper kernel typically casts incoming ``float16`` or ``float32`` CUDA tensors
+ to some preferred precision before running the op.
+ For example, matmuls and convolutions on floating-point CUDA tensors usually run faster
+ and use less memory in ``float16`` without impairing convergence.
+ Autocast wrappers only have an effect in
+ `autocast-enabled contexts <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_.
+
+ Here's an autocast wrapper for a hypothetical custom matmul, along with its registration:

.. code-block:: cpp
+
+   // Autocast-specific helper functions
+   #include <ATen/autocast_mode.h>

  Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
-     return mymatmul(autocast::_cast(at::kHalf, self), autocast::_cast(at::kHalf, other));
+     return mymatmul(at::autocast::cached_cast(at::kHalf, self),
+                     at::autocast::cached_cast(at::kHalf, other));
+   }
+
+   TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+     m.impl("mymatmul", mymatmul_autocast);
  }

+ ``cached_cast(kHalf, tensor)`` casts ``tensor`` to ``float16`` if ``tensor`` is CUDA and ``float32``;
+ otherwise, it leaves ``tensor`` unchanged (cf. the
+ `eligibility policy <https://pytorch.org/docs/stable/amp.html#op-eligibility>`_ for natively autocasted ops).
+ This ensures that if the network calls ``mymatmul`` on any mixture of ``float16`` and ``float32`` CUDA tensors,
+ ``mymatmul`` runs in ``float16``. Meanwhile, calls to ``mymatmul`` with non-CUDA, integer-type, or ``float64``
+ inputs are unaffected. Using ``cached_cast`` to follow the native eligibility policy in your own autocast wrapper
+ is recommended, but not required. For example, if you wanted to force ``float16`` execution for all input types,
+ you could ``return mymatmul(self.half(), other.half());`` instead of using ``cached_cast``.
+
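+ For illustration, here is a minimal sketch of that forced-``float16`` variant (the wrapper name
+ ``mymatmul_forced_fp16`` is hypothetical, not part of the tutorial's running example):
+
+ .. code-block:: cpp
+
+   // Sketch only: always run the hypothetical mymatmul in float16, ignoring
+   // the native eligibility policy that cached_cast follows.
+   Tensor mymatmul_forced_fp16(const Tensor& self, const Tensor& other) {
+     c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+     return mymatmul(self.half(), other.half());
+   }
+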
Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
- dispatch before redispatching. By default, if no autocast kernel is provided,
- we simply fallthrough directly to the regular operator implementation (no
- autocasting occurs.) (We didn't use ``myadd`` for this example, since pointwise
- addition doesn't do autocasting and should just fall through).
-
- When should an autocast kernel be registered? Unfortunately, there aren't
- cut-and-dry rules for when you should cast to a lower precision. You can
- get a sense for what operators have autocasting behavior by looking at
- the `AMP documentation
- <https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_. Some other
- general rules:
-
- * Operations that do reductions should be carried out in float32,
- * Any operation with multiple float tensor inputs has to standardize them
-   to a common precision, and
- * Any operation that does a convolution or gemm under the hood should
-   probably be float16
+ dispatch before redispatching. By default, if no autocast wrapper is provided,
+ we fall through directly to the regular operator implementation (no
+ autocasting occurs). (We didn't use ``myadd`` for this example, since pointwise
+ addition doesn't need autocasting and should just fall through.)
+
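+ If you want to make that intent explicit, one option (shown here only as a sketch; it is
+ redundant with the default behavior) is to register a fallthrough for the op under the
+ ``Autocast`` key yourself:
+
+ .. code-block:: cpp
+
+   // Optional: explicitly register a fallthrough so myadd is never autocast.
+   // This matches what happens anyway when no Autocast kernel is registered.
+   TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+     m.impl("myadd", torch::CppFunction::makeFallthrough());
+   }
+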
+ When should an autocast wrapper be registered? Unfortunately, there aren't
+ cut-and-dried rules for an op's preferred precision. You can
+ get a sense for some native ops' preferred precisions by looking at the
+ `cast lists <https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_.
+ General guidance:
+
+ * Ops that do reductions should probably execute in float32,
+ * Any op that does a convolution or gemm under the hood should
+   probably execute in float16, and
+ * Other ops with multiple floating-point tensor inputs should standardize
+   them to a common precision (unless the implementation is known to support
+   inputs with different precisions).
+
+ If your custom op falls into the third category, the ``promote_type`` template
+ helps figure out the widest floating-point type present among the input tensors, which is
+ usually the safest option for the execution type:
+
+ .. code-block:: cpp
+
+   #include <ATen/autocast_mode.h>
+
+   Tensor my_multiple_input_op_autocast(const Tensor& t0, const Tensor& t1) {
+     c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+     // Execute in the widest floating-point dtype among the inputs (at least float16).
+     auto exec_type = at::autocast::promote_type(at::kHalf, t0, t1);
+     return my_multiple_input_op(at::autocast::cached_cast(exec_type, t0),
+                                 at::autocast::cached_cast(exec_type, t1));
+   }
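+
+ As with ``mymatmul_autocast`` above, such a wrapper still needs to be registered under the
+ ``Autocast`` key. A sketch, assuming ``my_multiple_input_op`` lives in the same hypothetical
+ ``myops`` library:
+
+ .. code-block:: cpp
+
+   // Registration sketch, mirroring the mymatmul_autocast registration above.
+   TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+     m.impl("my_multiple_input_op", my_multiple_input_op_autocast);
+   }
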
Batched
^^^^^^^