Commit e534243

Add docs for torch.compile(numpy) (#109789)

ghstack-source-id: 3e29b38 Pull Request resolved: #109710
1 parent 01fa8c1 commit e534243

1 file changed: +156 -2 lines changed

docs/source/torch.compiler_faq.rst (156 additions & 2 deletions)
@@ -317,8 +317,8 @@ them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py`` 2.
 CUDA graphs with Triton are enabled by default in inductor but removing
 them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``.

-``torch.func`` works with ``torch.compile`` (for `grad` and `vmap` transforms)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Does ``torch.func`` work with ``torch.compile`` (for `grad` and `vmap` transforms)?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 Applying a ``torch.func`` transform to a function that uses ``torch.compile``
 does not work:
@@ -528,6 +528,160 @@ invokes an ``nn.Module``. This is because the outputs now depend on the
parameters of the ``nn.Module``. To get this to work, use
``torch.func.functional_call`` to extract the module state.

Does NumPy work with ``torch.compile``?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Starting in 2.1, ``torch.compile`` understands native NumPy programs that
work on NumPy arrays, and mixed PyTorch-NumPy programs that convert from PyTorch
to NumPy and back via ``x.numpy()``, ``torch.from_numpy``, and related functions.

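As a quick illustration (this particular function is just a made-up example),
a pure-NumPy function can be compiled directly:

.. code-block:: python

    import numpy as np
    import torch

    @torch.compile
    def numpy_normalize(x: np.ndarray) -> np.ndarray:
        # Traced by torch.compile even though it only uses the NumPy API
        return (x - x.mean(axis=0)) / x.std(axis=0)

    x = np.random.randn(64, 8)
    y = numpy_normalize(x)  # still a NumPy array; computed via PyTorch under the hood
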
.. _nonsupported-numpy-feats:

Which NumPy features does ``torch.compile`` support?
----------------------------------------------------

NumPy within ``torch.compile`` follows the NumPy 2.0 pre-release.

Generally, ``torch.compile`` is able to trace through most NumPy constructions,
and when it cannot, it falls back to eager mode and lets NumPy execute that piece
of code. Even then, there are a few features where the ``torch.compile`` semantics
deviate slightly from those of NumPy:

- NumPy scalars: We model them as 0-D arrays. That is, ``np.float32(3)`` returns
  a 0-D array under ``torch.compile``. To avoid a graph break, it is best to use this 0-D
  array. If this breaks your code, you can work around it by casting the NumPy scalar
  to the relevant Python scalar type ``bool/int/float``, as shown in the sketch after
  this list.

- Negative strides: ``np.flip`` and slicing with a negative step return a copy.

- Type promotion: NumPy's type promotion will change in NumPy 2.0. The new rules
  are described in `NEP 50 <https://numpy.org/neps/nep-0050-scalar-promotion.html>`__.
  ``torch.compile`` implements NEP 50 rather than the current, soon-to-be-deprecated rules.

- ``{tril,triu}_indices_from/{tril,triu}_indices`` return arrays rather than a tuple of arrays.

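A minimal sketch of the scalar workaround mentioned above (the function is
illustrative):

.. code-block:: python

    import numpy as np
    import torch

    @torch.compile
    def fn(x: np.ndarray) -> np.ndarray:
        # Under torch.compile, np.float32(3.0) behaves like a 0-D array.
        # Casting it to a Python float recovers plain scalar semantics.
        scale = float(np.float32(3.0))
        return x * scale

    print(fn(np.ones(4)))  # [3. 3. 3. 3.]
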
There are other features for which we do not support tracing and we gracefully
fall back to NumPy for their execution:

- Non-numeric dtypes like datetimes, strings, chars, void, structured dtypes and recarrays.

- Long dtypes ``np.float128/np.complex256`` and some unsigned dtypes ``np.uint16/np.uint32/np.uint64``.

- ``ndarray`` subclasses.

- Masked arrays.

- Esoteric ufunc machinery like ``axes=[(n,k),(k,m)->(n,m)]`` and ufunc methods (e.g., ``np.add.reduce``).

- Sorting / ordering ``complex64/complex128`` arrays.

- NumPy ``np.poly1d`` and ``np.polynomial``.

- Positional ``out1, out2`` args in functions with 2 or more returns (``out=tuple`` does work).

- ``__array_function__``, ``__array_interface__`` and ``__array_wrap__``.

- ``ndarray.ctypes`` attribute.

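The fallback is graceful: such code still runs under ``torch.compile``, just
without being traced. As an illustrative sketch with one of the ufunc methods
listed above:

.. code-block:: python

    import numpy as np
    import torch

    @torch.compile
    def fn(x: np.ndarray):
        # np.add.reduce is a ufunc method: it is not traced, so it is
        # executed eagerly by NumPy, while the rest of the function compiles.
        return np.add.reduce(x) + x.sum()

    print(fn(np.arange(5)))  # 20
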
Can I execute NumPy code on CUDA via ``torch.compile``?
-------------------------------------------------------

Yes, you can! To do so, simply execute your code within a ``torch.device("cuda")``
context. Consider the example:

.. code-block:: python

    import torch
    import numpy as np

    @torch.compile
    def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
        return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

    X = np.random.randn(1024, 64)
    Y = np.random.randn(1024, 64)
    with torch.device("cuda"):
        Z = numpy_fn(X, Y)


In this example, ``numpy_fn`` will be executed on CUDA. For this to be
possible, ``torch.compile`` automatically moves ``X`` and ``Y`` from CPU
to CUDA, and then it moves the result ``Z`` from CUDA back to CPU. If we are
executing this function several times in the same program run, we may want
to avoid all these rather expensive memory copies. To do so, we just need
to tweak our ``numpy_fn`` so that it accepts CUDA tensors and returns tensors:

.. code-block:: python

    @torch.compile
    def numpy_fn(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        X, Y = X.numpy(), Y.numpy()
        Z = np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
        return torch.from_numpy(Z)

    X = torch.randn(1024, 64, device="cuda")
    Y = torch.randn(1024, 64, device="cuda")
    with torch.device("cuda"):
        Z = numpy_fn(X, Y)

By doing this, we explicitly create the tensors in CUDA memory and keep
them there. In this case, ``X.numpy()`` and ``from_numpy()`` are hints to the
compiler, but no real data movement happens. Note that this version of the
program will not run in eager mode. If you want it to run in eager mode as well,
you need to call ``.numpy(force=True)`` and do ``Z = Z.cuda()`` before
returning ``Z``. Of course, doing this would execute the program in eager mode
NumPy, and on CPU.

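As a sketch of that eager-compatible variant (assuming the result should end up
back on CUDA):

.. code-block:: python

    def numpy_fn_eager(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        # force=True also works for CUDA tensors in eager mode, copying them to CPU
        X, Y = X.numpy(force=True), Y.numpy(force=True)
        Z = np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))
        # torch.from_numpy returns a CPU tensor, so move the result back to CUDA
        return torch.from_numpy(Z).cuda()
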

How do I debug NumPy code under ``torch.compile``?
--------------------------------------------------

Debugging JIT-compiled code is challenging, given the complexity of modern
compilers and the daunting errors that they raise.
`The tutorial on how to diagnose runtime errors within torch.compile <https://pytorch.org/docs/main/torch.compiler_troubleshooting.html#diagnosing-runtime-errors>`__
contains a few tips and tricks on how to tackle this task.

If the above is not enough to pinpoint the origin of the issue, there are still
a few other NumPy-specific tools we can use. We can discern whether the bug
is entirely in the PyTorch code by disabling tracing through NumPy functions:

.. code-block:: python

    from torch._dynamo import config
    config.trace_numpy = False

If the bug lies in the traced NumPy code, we can execute the NumPy code eagerly
(without ``torch.compile``) using PyTorch as a backend via ``import torch._numpy as np``.
This should be used for **debugging purposes only** and is in no way a
replacement for the PyTorch API, as it is **much less performant** and, as a
private API, **may change without notice**. At any rate, ``torch._numpy`` is a
Python implementation of NumPy in terms of PyTorch, and it is used internally by
``torch.compile`` to transform NumPy code into PyTorch code. It is rather easy to
read and modify, so if you find any bug in it, feel free to submit a PR fixing it
or simply open an issue.

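As a purely illustrative sketch of this debugging flow (the function here is
made up):

.. code-block:: python

    import torch._numpy as np  # debug only: NumPy implemented on top of PyTorch

    def fn(x):
        return np.mean(x ** 2, axis=0)

    # Runs eagerly, with no TorchDynamo tracing involved; if the bug still
    # reproduces here, it likely lies in torch._numpy itself.
    out = fn(np.ones((4, 3)))
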
If the program does work when importing ``torch._numpy as np``, chances are
that the bug is in TorchDynamo. If this is the case, please feel free to open an
issue with a `minimal reproducer <https://pytorch.org/docs/2.1/torch.compiler_troubleshooting.html>`__.

I applied ``torch.compile`` to some NumPy code and did not see any speed-up.
-----------------------------------------------------------------------------

The best place to start is the
`tutorial with general advice for how to debug these sorts of torch.compile issues <https://pytorch.org/docs/main/torch.compiler_faq.html#why-am-i-not-seeing-speedups>`__.

Some graph breaks may happen because of the use of unsupported features. See
:ref:`nonsupported-numpy-feats`. More generally, it is useful to keep in mind
that some widely used NumPy features do not play well with compilers. For
example, in-place modifications make reasoning difficult within the compiler and
often yield worse performance than their out-of-place counterparts. As such, it
is best to avoid them. The same goes for the use of the ``out=`` parameter.
Instead, prefer out-of-place ops and let ``torch.compile`` optimize the memory
use. The same goes for data-dependent ops like masked indexing through boolean
masks, or data-dependent control flow like ``if`` or ``while`` constructions. A
short sketch contrasting these styles follows.

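An illustrative sketch (the function is made up) of avoiding in-place updates
and ``out=`` in favor of out-of-place ops:

.. code-block:: python

    import numpy as np
    import torch

    @torch.compile
    def scale_shift(x: np.ndarray) -> np.ndarray:
        # Harder for the compiler to reason about:
        #   x *= 2.0
        #   np.add(x, 1.0, out=x)
        # Prefer the out-of-place form and let torch.compile manage memory:
        return x * 2.0 + 1.0
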

Which API to use for fine grain tracing?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~