@@ -105,6 +105,8 @@ speaking, the structure of your registrations will look like this:
that provides implementations for all basic operators on the XLA dispatch
key.
+ .. _autograd-support:
+
Adding autograd support
-----------------------
@@ -299,6 +301,28 @@ the safest choice for the execution type:
                                  at::autocast::cached_cast(exec_type, t1));
    }
+ If your custom op is :ref:`autograd-enabled<autograd-support>`, you only need to write and register
+ an autocast wrapper for the same name onto which the autograd wrapper is registered.
+ For example, if you wanted an autocast wrapper for the ``myadd`` function shown
+ in the autograd section, all you'd need is
+
+ .. code-block:: cpp
+
+     Tensor myadd_autocast(const Tensor& self, const Tensor& other) {
+       c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+       return myadd(at::autocast::cached_cast(<desired dtype>, self),
+                    at::autocast::cached_cast(<desired dtype>, other));
+     }
+
+     TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+       m.impl("myadd", myadd_autocast);
+     }
+
+ No separate gymnastics are needed to make the backward method autocast compatible.
+ However, the backward method defined in your custom autograd function will run in the same
+ dtype as autocast sets for the forward method, so you should choose a ``<desired dtype>``
+ suitable for both your forward and backward methods.
+
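+ For concreteness, here is a minimal sketch of that choice. It assumes ``at::kHalf``
+ provides enough precision for both your forward and backward methods; the dtype here
+ is only an illustrative assumption, not something prescribed by this tutorial, and the
+ snippet reuses the same ``myadd`` and includes as the wrapper above.
+
+ .. code-block:: cpp
+
+     Tensor myadd_autocast(const Tensor& self, const Tensor& other) {
+       c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+       // at::kHalf is an example choice: the backward method will also run in this
+       // dtype, so it must be accurate enough for the gradients as well.
+       return myadd(at::autocast::cached_cast(at::kHalf, self),
+                    at::autocast::cached_cast(at::kHalf, other));
+     }
+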
Batched
^^^^^^^