Skip to content

Commit cb428f2

Browse files
committed
Update default modes doc
1 parent 0f8d876 commit cb428f2

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

doc/library/compile/mode.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ PyTensor defines the following modes by name:
2020

2121
- ``'FAST_COMPILE'``: Apply just a few graph rewrites and only use Python implementations.
2222
- ``'FAST_RUN'``: Apply all rewrites, and use C implementations where possible.
23+
- ``NUMBA``: Apply all relevant related rewrites and compile the whole graph using Numba.
24+
- ``JAX``: Apply all relevant rewrites and compile the whole graph using JAX.
25+
- ``PYTORCH`` Apply all relevant rewrites and compile the whole graph using PyTorch compile.
2326
- ``'DebugMode'``: A mode for debugging. See :ref:`DebugMode <debugmode>` for details.
2427
- ``'NanGuardMode``: :ref:`Nan detector <nanguardmode>`
2528
- ``'DEBUG_MODE'``: Deprecated. Use the string DebugMode.
@@ -28,6 +31,12 @@ The default mode is typically ``FAST_RUN``, but it can be controlled via the
2831
configuration variable :attr:`config.mode`, which can be
2932
overridden by passing the keyword argument to :func:`pytensor.function`.
3033

34+
For Numba, JAX, and PyTorch, we exclude rewrites that introduce C-only Ops,
35+
as well as BLAS optimizations, as those are done automatically by the respective backends.
36+
37+
For JAX we also exclude fusion and inplace optimizations, as JAX does not support them
38+
at the user level. They are performed automatically by JAX.
39+
3140
.. TODO::
3241

3342
For a finer level of control over which rewrites are applied, and whether

0 commit comments

Comments
 (0)