Commit a824b85

Clarify autograd-autocast interaction for custom ops
1 parent: 85ab17c


advanced_source/dispatcher.rst

Lines changed: 24 additions & 0 deletions
@@ -105,6 +105,8 @@ speaking, the structure of your registrations will look like this:
 that provides implementations for all basic operators on the XLA dispatch
 key.
 
+.. _autograd-support:
+
 Adding autograd support
 -----------------------
 
@@ -299,6 +301,28 @@ the safest choice for the execution type:
                  at::autocast::cached_cast(exec_type, t1));
   }
 
+If your custom op is :ref:`autograd-enabled<autograd-support>`, you only need to write and register
+an autocast wrapper for the same name onto which the autograd wrapper is registered.
+For example, if you wanted an autocast wrapper for the ``myadd`` function shown
+in the autograd section, all you'd need is
+
+.. code-block:: cpp
+
+  Tensor myadd_autocast(const Tensor& self, const Tensor& other) {
+    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+    return myadd(at::autocast::cached_cast(<desired dtype>, self),
+                 at::autocast::cached_cast(<desired dtype>, other));
+  }
+
+  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+    m.impl("myadd", myadd_autocast);
+  }
+
+There are no separate gymnastics to make the backward method autocast compatible.
+However, the backward method defined in your custom autograd function will run in the same
+dtype as autocast sets for the forward method, so you should choose a ``<desired dtype>``
+suitable for both your forward and backward methods.
+
 Batched
 ^^^^^^^

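For concreteness, here is a minimal, self-contained sketch of the added wrapper with the ``<desired dtype>`` placeholder filled in. ``at::kHalf`` is only an illustrative choice, not one mandated by the tutorial, and ``myadd`` is assumed to be the autograd-enabled op registered in the autograd section:

.. code-block:: cpp

  #include <ATen/autocast_mode.h>
  #include <torch/library.h>

  // Assumed: the autograd-enabled op from the autograd section.
  at::Tensor myadd(const at::Tensor& self, const at::Tensor& other);

  at::Tensor myadd_autocast(const at::Tensor& self, const at::Tensor& other) {
    // Drop the Autocast key for the duration of this call so the
    // redispatch to myadd below does not re-enter this wrapper.
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    // cached_cast converts each input to the target dtype, caching casts of
    // reused tensors such as weights. at::kHalf stands in for <desired dtype>;
    // the backward pass will see the same dtype, so pick one that is
    // numerically safe in both directions.
    return myadd(at::autocast::cached_cast(at::kHalf, self),
                 at::autocast::cached_cast(at::kHalf, other));
  }

  TORCH_LIBRARY_IMPL(myops, Autocast, m) {
    m.impl("myadd", myadd_autocast);
  }

With this registration in place, a call to ``myadd`` inside an autocast-enabled region routes through ``myadd_autocast`` first, while calls made outside autocast are unaffected.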