diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py
index 464267271cf..beccbc8f546 100644
--- a/.jenkins/validate_tutorials_built.py
+++ b/.jenkins/validate_tutorials_built.py
@@ -29,6 +29,7 @@
     "intermediate_source/fx_conv_bn_fuser",
     "intermediate_source/_torch_export_nightly_tutorial",  # does not work on release
     "advanced_source/super_resolution_with_onnxruntime",
+    "advanced_source/python_custom_ops",  # https://github.com/pytorch/pytorch/issues/127443
     "advanced_source/ddp_pipeline",  # requires 4 gpus
     "advanced_source/usb_semisup_learn",  # fails with CUDA OOM error, should try on a different worker
     "prototype_source/fx_graph_mode_ptq_dynamic",

diff --git a/advanced_source/cpp_custom_ops.rst b/advanced_source/cpp_custom_ops.rst
new file mode 100644
index 00000000000..66df7344522
--- /dev/null
+++ b/advanced_source/cpp_custom_ops.rst
@@ -0,0 +1,418 @@

.. _cpp-custom-ops-tutorial:

Custom C++ and CUDA Operators
=============================

**Author:** `Richard Zou `_

.. grid:: 2

   .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

      * How to integrate custom operators written in C++/CUDA with PyTorch
      * How to test custom operators using ``torch.library.opcheck``

   .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

      * PyTorch 2.4 or later
      * Basic understanding of C++ and CUDA programming

PyTorch offers a large library of operators that work on Tensors (for example,
``torch.add`` and ``torch.sum``). However, you may wish to bring a new custom
operator to PyTorch. This tutorial demonstrates the recommended path to authoring
a custom operator written in C++/CUDA.

For our tutorial, we'll demonstrate how to author a fused multiply-add C++
and CUDA operator that composes with PyTorch subsystems. The semantics of
the operation are as follows:

.. code-block:: python

   def mymuladd(a: Tensor, b: Tensor, c: float):
       return a * b + c

You can find the end-to-end working example for this tutorial
`here `_.

Setting up the Build System
---------------------------

If you are developing custom C++/CUDA code, it must be compiled.
Note that if you're interfacing with a Python library that already has bindings
to precompiled C++/CUDA code, you might consider writing a custom Python operator
instead (:ref:`python-custom-ops-tutorial`).

Use `torch.utils.cpp_extension `_
to compile custom C++/CUDA code for use with PyTorch.
C++ extensions may be built either "ahead of time" with setuptools, or "just in time"
via `load_inline `_;
we'll focus on the "ahead of time" flavor.

Using ``cpp_extension`` is as simple as writing the following ``setup.py``:

.. code-block:: python

   from setuptools import setup
   from torch.utils import cpp_extension

   setup(name="extension_cpp",
         ext_modules=[
             cpp_extension.CppExtension("extension_cpp", ["muladd.cpp"])],
         cmdclass={'build_ext': cpp_extension.BuildExtension})

If you need to compile CUDA code (for example, ``.cu`` files), then instead use
`torch.utils.cpp_extension.CUDAExtension `_.
Please see `extension-cpp `_ for an example of
how this is set up.
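For reference, a CUDA-enabled ``setup.py`` might look like the following sketch.
The extra ``muladd.cu`` source file is an assumption about how this tutorial's
kernels would be laid out on disk; substitute your own file names:

.. code-block:: python

   from setuptools import setup
   from torch.utils import cpp_extension

   setup(name="extension_cpp",
         ext_modules=[
             # CUDAExtension compiles both the C++ bindings and the .cu kernels.
             cpp_extension.CUDAExtension(
                 "extension_cpp", ["muladd.cpp", "muladd.cu"])],
         cmdclass={'build_ext': cpp_extension.BuildExtension})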
Defining the custom op and adding backend implementations
---------------------------------------------------------
First, let's write a C++ function that computes ``mymuladd``:

.. code-block:: cpp

   at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) {
     TORCH_CHECK(a.sizes() == b.sizes());
     TORCH_CHECK(a.dtype() == at::kFloat);
     TORCH_CHECK(b.dtype() == at::kFloat);
     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* result_ptr = result.data_ptr<float>();
     for (int64_t i = 0; i < result.numel(); i++) {
       result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
     }
     return result;
   }

In order to use this from PyTorch's Python frontend, we need to register it
as a PyTorch operator using the ``TORCH_LIBRARY`` API. This will automatically
bind the operator to Python.

Operator registration is a two-step process:

- **Defining the operator** - This step ensures that PyTorch is aware of the new operator.
- **Registering backend implementations** - In this step, implementations for various
  backends, such as CPU and CUDA, are associated with the operator.

Defining an operator
^^^^^^^^^^^^^^^^^^^^
To define an operator, follow these steps:

1. Select a namespace for the operator. We recommend the namespace be the name of your
   top-level project; we'll use ``extension_cpp`` in our tutorial.
2. Provide a schema string that specifies the input/output types of the operator and
   whether any input Tensors will be mutated. We support more types in addition to
   Tensor and float; please see
   `The Custom Operators Manual `_ for more details.

   * If you are authoring an operator that can mutate its input Tensors, please see
     :ref:`mutable-ops` for how to specify that.

.. code-block:: cpp

   TORCH_LIBRARY(extension_cpp, m) {
     // Note that "float" in the schema corresponds to the C++ double type
     // and the Python float type.
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
   }

This makes the operator available from Python via ``torch.ops.extension_cpp.mymuladd``.

Registering backend implementations for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Use ``TORCH_LIBRARY_IMPL`` to register a backend implementation for the operator.

.. code-block:: cpp

   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
     m.impl("mymuladd", &mymuladd_cpu);
   }
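At this point, a quick sanity check from Python is possible. This is a sketch that
assumes the extension built above has been compiled and is importable as
``extension_cpp``:

.. code-block:: python

   import torch
   import extension_cpp  # importing the extension loads the C++ registrations

   a, b = torch.randn(3), torch.randn(3)
   result = torch.ops.extension_cpp.mymuladd(a, b, 1.0)
   assert torch.allclose(result, a * b + 1.0)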
If you also have a CUDA implementation of ``mymuladd``, you can register it
in a separate ``TORCH_LIBRARY_IMPL`` block:

.. code-block:: cpp

   __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < numel) result[idx] = a[idx] * b[idx] + c;
   }

   at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
     TORCH_CHECK(a.sizes() == b.sizes());
     TORCH_CHECK(a.dtype() == at::kFloat);
     TORCH_CHECK(b.dtype() == at::kFloat);
     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* result_ptr = result.data_ptr<float>();

     int numel = a_contig.numel();
     muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
     return result;
   }

   TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
     m.impl("mymuladd", &mymuladd_cuda);
   }

Adding ``torch.compile`` support for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To add ``torch.compile`` support for an operator, we must add a FakeTensor kernel (also
known as a "meta kernel" or "abstract impl"). FakeTensors are Tensors that have
metadata (such as shape, dtype, and device) but no data: the FakeTensor kernel for an
operator specifies how to compute the metadata of output tensors given the metadata
of input tensors.

We recommend that this be done from Python via the ``torch.library.register_fake`` API,
though it is possible to do this from C++ as well (see
`The Custom Operators Manual `_
for more details).

.. code-block:: python

   # Important: the C++ custom operator definitions should be loaded first
   # before calling ``torch.library`` APIs that add registrations for the
   # C++ custom operator(s). The following import loads our
   # C++ custom operator definitions.
   # See the next section for more details.
   from . import _C

   @torch.library.register_fake("extension_cpp::mymuladd")
   def _(a, b, c):
       torch._check(a.shape == b.shape)
       torch._check(a.dtype == torch.float)
       torch._check(b.dtype == torch.float)
       torch._check(a.device == b.device)
       return torch.empty_like(a)

Setting up hybrid Python/C++ registration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In this tutorial, we defined a custom operator in C++, added CPU/CUDA
implementations in C++, and added ``FakeTensor`` kernels and backward formulas
in Python. The order in which these registrations are loaded (or imported)
matters (importing in the wrong order will lead to an error).

To use the custom operator with hybrid Python/C++ registrations, we must
first load the C++ library that holds the custom operator definition
and then call the ``torch.library`` registration APIs. This can happen in one
of two ways, as sketched after the list:

1. If you're following this tutorial, importing the Python C extension module
   we created will load the C++ custom operator definitions.
2. If your C++ custom operator is located in a shared library object, you can
   also use ``torch.ops.load_library("/path/to/library.so")`` to load it.
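A minimal sketch of that ordering (the library path is illustrative):

.. code-block:: python

   import torch

   # First, load the C++ definitions (either of the two options above).
   torch.ops.load_library("/path/to/library.so")  # illustrative path

   # Only now is it safe to add Python-side registrations, such as
   # torch.library.register_fake or torch.library.register_autograd,
   # for "extension_cpp::mymuladd".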
Adding training (autograd) support for an operator
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Use ``torch.library.register_autograd`` to add training support for an operator. Prefer
this over directly using Python ``torch.autograd.Function`` or C++
``torch::autograd::Function``; you must use those in a very specific way to avoid silent
incorrectness (see `The Custom Operators Manual `_
for more details).

.. code-block:: python

   def _backward(ctx, grad):
       a, b = ctx.saved_tensors
       grad_a, grad_b = None, None
       if ctx.needs_input_grad[0]:
           grad_a = grad * b
       if ctx.needs_input_grad[1]:
           grad_b = grad * a
       return grad_a, grad_b, None

   def _setup_context(ctx, inputs, output):
       a, b, c = inputs
       saved_a, saved_b = None, None
       if ctx.needs_input_grad[0]:
           saved_b = b
       if ctx.needs_input_grad[1]:
           saved_a = a
       ctx.save_for_backward(saved_a, saved_b)

   # This code adds training support for the operator. You must provide the
   # backward formula for the operator and a ``setup_context`` function
   # to save values to be used in the backward.
   torch.library.register_autograd(
       "extension_cpp::mymuladd", _backward, setup_context=_setup_context)

Note that the backward must be a composition of PyTorch-understood operators.
If you wish to use another custom C++ or CUDA kernel in your backward pass,
it must be wrapped into a custom operator.

If we had our own custom ``mymul`` kernel, we would need to wrap it into a
custom operator and then call that from the backward:

.. code-block:: cpp

   // New! a mymul_cpu kernel
   at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) {
     TORCH_CHECK(a.sizes() == b.sizes());
     TORCH_CHECK(a.dtype() == at::kFloat);
     TORCH_CHECK(b.dtype() == at::kFloat);
     TORCH_CHECK(a.device().type() == at::DeviceType::CPU);
     TORCH_CHECK(b.device().type() == at::DeviceType::CPU);
     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* result_ptr = result.data_ptr<float>();
     for (int64_t i = 0; i < result.numel(); i++) {
       result_ptr[i] = a_ptr[i] * b_ptr[i];
     }
     return result;
   }

   TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
     // New! defining the mymul operator
     m.def("mymul(Tensor a, Tensor b) -> Tensor");
   }


   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
     m.impl("mymuladd", &mymuladd_cpu);
     // New! registering the CPU kernel for the mymul operator
     m.impl("mymul", &mymul_cpu);
   }

.. code-block:: python

   def _backward(ctx, grad):
       a, b = ctx.saved_tensors
       grad_a, grad_b = None, None
       if ctx.needs_input_grad[0]:
           grad_a = torch.ops.extension_cpp.mymul.default(grad, b)
       if ctx.needs_input_grad[1]:
           grad_b = torch.ops.extension_cpp.mymul.default(grad, a)
       return grad_a, grad_b, None


   def _setup_context(ctx, inputs, output):
       a, b, c = inputs
       saved_a, saved_b = None, None
       if ctx.needs_input_grad[0]:
           saved_b = b
       if ctx.needs_input_grad[1]:
           saved_a = a
       ctx.save_for_backward(saved_a, saved_b)


   # This code adds training support for the operator. You must provide the
   # backward formula for the operator and a ``setup_context`` function
   # to save values to be used in the backward.
   torch.library.register_autograd(
       "extension_cpp::mymuladd", _backward, setup_context=_setup_context)
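With these registrations in place, gradients flow through the operator in eager
mode. A quick smoke test (a sketch, assuming the extension is built and loaded):

.. code-block:: python

   a = torch.randn(3, requires_grad=True)
   b = torch.randn(3, requires_grad=True)
   torch.ops.extension_cpp.mymuladd(a, b, 1.0).sum().backward()
   # d(a*b + c)/da == b and d(a*b + c)/db == a
   assert torch.allclose(a.grad, b)
   assert torch.allclose(b.grad, a)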
Testing an operator
-------------------
Use ``torch.library.opcheck`` to test that the custom op was registered correctly.
Note that this function does not test that the gradients are mathematically correct
-- plan to write separate tests for that, either manual ones or by using
``torch.autograd.gradcheck``.

.. code-block:: python

   def sample_inputs(device, *, requires_grad=False):
       def make_tensor(*size):
           return torch.randn(size, device=device, requires_grad=requires_grad)

       def make_nondiff_tensor(*size):
           return torch.randn(size, device=device, requires_grad=False)

       return [
           [make_tensor(3), make_tensor(3), 1],
           [make_tensor(20), make_tensor(20), 3.14],
           [make_tensor(20), make_nondiff_tensor(20), -123],
           [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
       ]

   def reference_muladd(a, b, c):
       return a * b + c

   device = "cpu"  # repeat with "cuda" (or others) for each backend you registered
   samples = sample_inputs(device, requires_grad=True)
   samples.extend(sample_inputs(device, requires_grad=False))
   for args in samples:
       # Correctness test
       result = torch.ops.extension_cpp.mymuladd(*args)
       expected = reference_muladd(*args)
       torch.testing.assert_close(result, expected)

       # Use opcheck to check for incorrect usage of operator registration APIs
       torch.library.opcheck(torch.ops.extension_cpp.mymuladd.default, args)

.. _mutable-ops:

Creating mutable operators
--------------------------
You may wish to author a custom operator that mutates its inputs. Use ``Tensor(a!)``
to specify each mutable Tensor in the schema; otherwise, there will be undefined
behavior. If there are multiple mutated Tensors, use different names (for example,
``Tensor(a!)``, ``Tensor(b!)``, ``Tensor(c!)``) for each mutable Tensor.

Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of ``a+b`` into ``out``.

.. code-block:: cpp

   // An example of an operator that mutates one of its inputs.
   void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
     TORCH_CHECK(a.sizes() == b.sizes());
     TORCH_CHECK(b.sizes() == out.sizes());
     TORCH_CHECK(a.dtype() == at::kFloat);
     TORCH_CHECK(b.dtype() == at::kFloat);
     TORCH_CHECK(out.dtype() == at::kFloat);
     TORCH_CHECK(out.is_contiguous());
     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
     TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU);
     at::Tensor a_contig = a.contiguous();
     at::Tensor b_contig = b.contiguous();
     const float* a_ptr = a_contig.data_ptr<float>();
     const float* b_ptr = b_contig.data_ptr<float>();
     float* result_ptr = out.data_ptr<float>();
     for (int64_t i = 0; i < out.numel(); i++) {
       result_ptr[i] = a_ptr[i] + b_ptr[i];
     }
   }

When defining the operator, we must specify that it mutates the ``out`` Tensor in the schema:

.. code-block:: cpp

   TORCH_LIBRARY(extension_cpp, m) {
     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
     m.def("mymul(Tensor a, Tensor b) -> Tensor");
     // New!
     m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
   }

   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
     m.impl("mymuladd", &mymuladd_cpu);
     m.impl("mymul", &mymul_cpu);
     // New!
     m.impl("myadd_out", &myadd_out_cpu);
   }
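From Python, the mutable operator is invoked like any other (a sketch, assuming the
extension is built and loaded):

.. code-block:: python

   a, b = torch.randn(3), torch.randn(3)
   out = torch.empty(3)
   torch.ops.extension_cpp.myadd_out(a, b, out)  # returns nothing; writes into out
   assert torch.allclose(out, a + b)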
.. note::

   Do not return any mutated Tensors as outputs of the operator, as this will
   cause incompatibility with PyTorch subsystems like ``torch.compile``.

Conclusion
----------
In this tutorial, we went over the recommended approach to integrating custom C++
and CUDA operators with PyTorch. The ``TORCH_LIBRARY/torch.library`` APIs are fairly
low-level. For more information about how to use the API, see
`The Custom Operators Manual `_.

diff --git a/advanced_source/custom_ops_landing_page.rst b/advanced_source/custom_ops_landing_page.rst
new file mode 100644
index 00000000000..ebb238ef63e
--- /dev/null
+++ b/advanced_source/custom_ops_landing_page.rst
@@ -0,0 +1,60 @@

.. _custom-ops-landing-page:

PyTorch Custom Operators Landing Page
=====================================

PyTorch offers a large library of operators that work on Tensors (for example,
``torch.add`` and ``torch.sum``). However, you may wish to bring a new custom
operation to PyTorch and get it to work with subsystems like ``torch.compile``,
autograd, and ``torch.vmap``. In order to do so, you must register the custom
operation with PyTorch via the Python
`torch.library docs `_ or C++ ``TORCH_LIBRARY``
APIs.

TL;DR
-----

Authoring a custom operator from Python
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please see :ref:`python-custom-ops-tutorial`.

You may wish to author a custom operator from Python (as opposed to C++) if:

- you have a Python function you want PyTorch to treat as an opaque callable,
  especially with respect to ``torch.compile`` and ``torch.export``.
- you have some Python bindings to C++/CUDA kernels and want those to compose
  with PyTorch subsystems (like ``torch.compile`` or ``torch.autograd``).

Integrating custom C++ and/or CUDA code with PyTorch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please see :ref:`cpp-custom-ops-tutorial`.

You may wish to author a custom operator from C++ (as opposed to Python) if:

- you have custom C++ and/or CUDA code.
- you plan to use this code with ``AOTInductor`` to do Python-less inference.

The Custom Operators Manual
^^^^^^^^^^^^^^^^^^^^^^^^^^^

For information not covered in the tutorials and this page, please see
`The Custom Operators Manual `_
(we're working on moving the information to our docs site). We recommend that you
first read one of the tutorials above and then use the Custom Operators Manual as a
reference; it is not meant to be read cover to cover.

When should I create a Custom Operator?
---------------------------------------
If your operation is expressible as a composition of built-in PyTorch operators,
then please write it as a Python function and call it instead of creating a
custom operator. Use the operator registration APIs to create a custom operator if you
are calling into some library that PyTorch doesn't understand (for example, custom
C/C++ code, a custom CUDA kernel, or Python bindings to C/C++/CUDA extensions).

Why should I create a Custom Operator?
--------------------------------------

It is possible to use a C/C++/CUDA kernel by grabbing a Tensor's data pointer
and passing it to a pybind'ed kernel. However, this approach doesn't compose with
PyTorch subsystems like autograd, ``torch.compile``, ``torch.vmap``, and more. In
order for an operation to compose with PyTorch subsystems, it must be registered
via the operator registration APIs.
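To make the first rule of thumb concrete, here is a minimal illustration (not taken
from the tutorials): the fused multiply-add below is built entirely from operators
PyTorch already understands, so it should stay a plain function and needs no
registration.

.. code-block:: python

   import torch

   # Expressible with built-in ops: keep it a plain Python function.
   # Autograd and torch.compile already understand everything inside.
   def mymuladd(a: torch.Tensor, b: torch.Tensor, c: float) -> torch.Tensor:
       return a * b + c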
diff --git a/advanced_source/python_custom_ops.py b/advanced_source/python_custom_ops.py
new file mode 100644
index 00000000000..36045cb9e48
--- /dev/null
+++ b/advanced_source/python_custom_ops.py
@@ -0,0 +1,262 @@

# -*- coding: utf-8 -*-

"""
.. _python-custom-ops-tutorial:

Python Custom Operators
=======================

.. grid:: 2

   .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn

      * How to integrate custom operators written in Python with PyTorch
      * How to test custom operators using ``torch.library.opcheck``

   .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

      * PyTorch 2.4 or later

PyTorch offers a large library of operators that work on Tensors (for example,
``torch.add`` and ``torch.sum``). However, you might wish to use a new customized
operator with PyTorch, perhaps written by a third-party library. This tutorial
shows how to wrap Python functions so that they behave like PyTorch native
operators. Reasons why you may wish to create a custom operator in PyTorch include:

- Treating an arbitrary Python function as an opaque callable with respect
  to ``torch.compile`` (that is, preventing ``torch.compile`` from tracing
  into the function).
- Adding training support to an arbitrary Python function.

Please note that if your operation can be expressed as a composition of
existing PyTorch operators, then there is usually no need to use the custom operator
API -- everything (for example ``torch.compile``, training support) should
just work.
"""
######################################################################
# Example: Wrapping PIL's crop into a custom operator
# ---------------------------------------------------
# Let's say that we are using PIL's ``crop`` operation.

import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import matplotlib.pyplot as plt

def crop(pic, box):
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return pil_to_tensor(cropped_img).to(pic.device) / 255.

def display(img):
    plt.imshow(img.numpy().transpose((1, 2, 0)))

img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)

######################################################################

cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)

######################################################################
# ``crop`` is not handled effectively out-of-the-box by
# ``torch.compile``: ``torch.compile`` induces a
# `"graph break" `_
# on functions it is unable to handle, and graph breaks are bad for performance.
# The following code demonstrates this by raising an error
# (``torch.compile`` with ``fullgraph=True`` raises an error if a
# graph break occurs).

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

# The following raises an error. Uncomment the line to see it.
# cropped_img = f(img)

######################################################################
# In order to black-box ``crop`` for use with ``torch.compile``, we need to
# do two things:
#
# 1. wrap the function into a PyTorch custom operator.
# 2. add a "``FakeTensor`` kernel" (also known as a "meta kernel") to the operator.
#    Given the metadata (for example, shapes) of the input Tensors, this function
#    specifies how to compute the metadata of the output Tensor(s).


from typing import Sequence

# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

# Use register_fake to add a ``FakeTensor`` kernel for the operator
@crop.register_fake
def _(pic, box):
    channels = pic.shape[0]
    x0, y0, x1, y1 = box
    return pic.new_empty(channels, y1 - y0, x1 - x0)

######################################################################
# After this, ``crop`` now works without graph breaks:

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

cropped_img = f(img)
display(img)

######################################################################

display(cropped_img)

######################################################################
# Adding training support for crop
# --------------------------------
# Use ``torch.library.register_autograd`` to add training support for an operator.
# Prefer this over directly using ``torch.autograd.Function``; some compositions of
# ``autograd.Function`` with PyTorch operator registration APIs can lead to (and
# have led to) silent incorrectness when composed with ``torch.compile``.
#
# The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the
# derivation as an exercise to the reader). Let's first wrap ``paste`` into a
# custom operator:

@torch.library.custom_op("mylib::paste", mutates_args=())
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    im1_pil = to_pil_image(im1.cpu())
    im2_pil = to_pil_image(im2.cpu())
    PIL.Image.Image.paste(im1_pil, im2_pil, coord)
    return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)

@paste.register_fake
def _(im1, im2, coord):
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    return torch.empty_like(im1)

######################################################################
# And now let's use ``register_autograd`` to specify the gradient formula for ``crop``:

def backward(ctx, grad_output):
    grad_input = grad_output.new_zeros(ctx.pic_shape)
    grad_input = paste(grad_input, grad_output, ctx.coords)
    return grad_input, None

def setup_context(ctx, inputs, output):
    pic, box = inputs
    ctx.coords = box[:2]
    ctx.pic_shape = pic.shape

crop.register_autograd(backward, setup_context=setup_context)

######################################################################
# Note that the backward must be a composition of PyTorch-understood operators,
# which is why we wrapped ``paste`` into a custom operator instead of directly using
# PIL's paste.

img = img.requires_grad_()
result = crop(img, (10, 10, 50, 50))
result.sum().backward()
display(img.grad)

######################################################################
# This is the correct gradient, with 1s (white) in the cropped region and 0s
# (black) in the unused region.
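######################################################################
# A short end-to-end sketch: the custom operator now composes with
# ``torch.compile`` and autograd at the same time.

@torch.compile(fullgraph=True)
def train_step(pic):
    return crop(pic, (10, 10, 50, 50)).sum()

x = torch.rand(3, 64, 64, requires_grad=True)
train_step(x).backward()
assert x.grad is not None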
######################################################################
# Testing Python Custom operators
# -------------------------------
# Use ``torch.library.opcheck`` to test that the custom operator was registered
# correctly. This does not test that the gradients are mathematically correct;
# please write separate tests for that (either manual ones or ``torch.autograd.gradcheck``).
#
# To use ``opcheck``, pass it a set of example inputs to test against. If your
# operator supports training, then the examples should include Tensors that
# require grad. If your operator supports multiple devices, then the examples
# should include Tensors from each device.

examples = [
    [torch.randn(3, 64, 64), [0, 0, 10, 10]],
    [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
    [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
    [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
]

for example in examples:
    torch.library.opcheck(crop, example)

######################################################################
# Mutable Python Custom operators
# -------------------------------
# You can also wrap a Python function that mutates its inputs into a custom
# operator.
# Functions that mutate inputs are common because that is how many low-level
# kernels are written; for example, a kernel that computes ``sin`` may take in
# the input and an output tensor and write ``input.sin()`` to the output tensor.
#
# We'll use ``numpy.sin`` to demonstrate an example of a mutable Python
# custom operator.

import numpy as np

@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.device == output.device
    assert input.device.type == "cpu"
    input_np = input.numpy()
    output_np = output.numpy()
    np.sin(input_np, out=output_np)

######################################################################
# Because the operator doesn't return anything, there is no need to register
# a ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``.

@torch.compile(fullgraph=True)
def f(x):
    out = torch.empty(3)
    numpy_sin(x, out)
    return out

x = torch.randn(3)
y = f(x)
assert torch.allclose(y, x.sin())

######################################################################
# And here's an ``opcheck`` run telling us that we did indeed register the operator correctly.
# ``opcheck`` would error out if we forgot to add the output to ``mutates_args``, for example.

example_inputs = [
    [torch.randn(3), torch.empty(3)],
    [torch.randn(0, 3), torch.empty(0, 3)],
    [torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
]

for example in example_inputs:
    torch.library.opcheck(numpy_sin, example)

######################################################################
# Conclusion
# ----------
# In this tutorial, we learned how to use ``torch.library.custom_op`` to
# create a custom operator in Python that works with PyTorch subsystems
# such as ``torch.compile`` and autograd.
#
# This tutorial provides a basic introduction to custom operators.
# For more detailed information, see:
#
# - `the torch.library documentation `_
# - `the Custom Operators Manual `_
#

diff --git a/en-wordlist.txt b/en-wordlist.txt
index d065c9234b5..d5674cdd215 100644
--- a/en-wordlist.txt
+++ b/en-wordlist.txt
@@ -177,6 +177,7 @@
 OpenSlide
 Opset
 Ornstein
 PIL
+PIL's
 PPO
 PatchPredictor
 PennFudan
@@ -445,6 +446,7 @@
 multithreading
 namespace
 natively
 ndarrays
+nightlies
 num
 numericalize
 numpy

diff --git a/index.rst b/index.rst
index 112e8898404..b77aaf9c8d4 100644
--- a/index.rst
+++ b/index.rst
@@ -416,6 +416,27 @@ Welcome to PyTorch Tutorials
    :link: advanced/cpp_frontend.html
    :tags: Frontend-APIs,C++

+.. customcarditem::
+   :header: PyTorch Custom Operators Landing Page
+   :card_description: This is the landing page for all things related to custom operators in PyTorch.
+   :image: _static/img/thumbnails/cropped/Custom-Cpp-and-CUDA-Extensions.png
+   :link: advanced/custom_ops_landing_page.html
+   :tags: Extending-PyTorch,Frontend-APIs,C++,CUDA
+
+.. customcarditem::
+   :header: Python Custom Operators
+   :card_description: Create Custom Operators in Python. Useful for black-boxing a Python function for use with torch.compile.
+   :image: _static/img/thumbnails/cropped/Custom-Cpp-and-CUDA-Extensions.png
+   :link: advanced/python_custom_ops.html
+   :tags: Extending-PyTorch,Frontend-APIs,C++,CUDA
+
+.. customcarditem::
+   :header: Custom C++ and CUDA Operators
+   :card_description: How to extend PyTorch with custom C++ and CUDA operators.
+   :image: _static/img/thumbnails/cropped/Custom-Cpp-and-CUDA-Extensions.png
+   :link: advanced/cpp_custom_ops.html
+   :tags: Extending-PyTorch,Frontend-APIs,C++,CUDA
+
 .. customcarditem::
    :header: Custom C++ and CUDA Extensions
    :card_description: Create a neural network layer with no parameters using numpy. Then use scipy to create a neural network layer that has learnable weights.

@@ -574,7 +595,7 @@ Welcome to PyTorch Tutorials
 .. customcarditem::
    :header: (beta) Accelerating BERT with semi-structured sparsity
-   :card_description: Train BERT, prune it to be 2:4 sparse, and then accelerate it to achieve 2x inference speedups with semi-structured sparsity and torch.compile. 
+   :card_description: Train BERT, prune it to be 2:4 sparse, and then accelerate it to achieve 2x inference speedups with semi-structured sparsity and torch.compile.
    :image: _static/img/thumbnails/cropped/Pruning-Tutorial.png
    :link: advanced/semi_structured_sparse.html
    :tags: Text,Model-Optimization

@@ -783,7 +804,7 @@ Welcome to PyTorch Tutorials
 .. customcarditem::
    :header: Using the ExecuTorch SDK to Profile a Model
-   :card_description: Explore how to use the ExecuTorch SDK to profile, debug, and visualize ExecuTorch models 
+   :card_description: Explore how to use the ExecuTorch SDK to profile, debug, and visualize ExecuTorch models
    :image: _static/img/ExecuTorch-Logo-cropped.svg
    :link: https://pytorch.org/executorch/stable/tutorials/sdk-integration-tutorial.html
    :tags: Edge

@@ -921,6 +942,7 @@ Additional Resources
    beginner/basics/autogradqs_tutorial
    beginner/basics/optimization_tutorial
    beginner/basics/saveloadrun_tutorial
+   advanced/custom_ops_landing_page

@@ -1067,6 +1089,9 @@ Additional Resources
    :hidden:
    :caption: Extending PyTorch

+   advanced/custom_ops_landing_page
+   advanced/python_custom_ops
+   advanced/cpp_custom_ops
    intermediate/custom_function_double_backward_tutorial
    intermediate/custom_function_conv_bn_tutorial
    advanced/cpp_extension

@@ -1137,7 +1162,7 @@ Additional Resources
    Using the ExecuTorch SDK to Profile a Model
    Building an ExecuTorch iOS Demo App
    Building an ExecuTorch Android Demo App
-   Lowering a Model as a Delegate 
+   Lowering a Model as a Delegate

 .. toctree::
    :maxdepth: 2