|
| 1 | +Custom C++ and CUDA Operators |
| 2 | +============================= |
| 3 | + |
| 4 | +.. note:: |
| 5 | + This tutorial is for PyTorch 2.4+ and the PyTorch nightlies. |
| 6 | + |
| 7 | +PyTorch offers a large library of operators that work on Tensors (e.g. torch.add, torch.sum, etc). |
| 8 | +However, you may wish to bring a new custom operator to PyTorch. This tutorial demonstrates the |
| 9 | +blessed path to authoring a custom operator written in C++/CUDA. |
| 10 | + |
| 11 | +For our tutorial, we’ll demonstrate how to author a fused multiply-add C++ |
| 12 | +and CUDA operator that composes with PyTorch subsystems. The semantics of |
| 13 | +the operation are as follows: |
| 14 | + |
| 15 | +.. code-block:: python |
| 16 | +
|
| 17 | + def mymuladd(a: Tensor, b: Tensor, c: float): |
| 18 | + return a * b + c |
| 19 | +
|
| 20 | +You can find the end-to-end working example for this tutorial over at |
| 21 | +https://github.com/pytorch/extension-cpp . |
| 22 | + |
| 23 | +Build System |
| 24 | +------------ |
| 25 | + |
| 26 | +If you author custom C++/CUDA code, it needs to be compiled somehow. |
| 27 | +Note that if you’re interfacing with a Python library that already has bindings |
| 28 | +to precompiled C++/CUDA code, then you may actually want to write a Python custom operator |
| 29 | +(TODO: tutorial) |
| 30 | + |
| 31 | +Use `torch.utils.cpp_extension <https://pytorch.org/docs/stable/cpp_extension.html>`_ |
| 32 | +to compile custom C++/CUDA code for use with PyTorch |
| 33 | +C++ extensions may be built either "ahead of time" with setuptools, or "just in time" |
| 34 | +via `load_inline <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load_inline>`; |
| 35 | +we’ll focus on the "ahead of time" flavor. |
| 36 | + |
| 37 | +Using cpp_extension is as simple as writing the following setup.py: |
| 38 | + |
| 39 | +.. code-block:: python |
| 40 | +
|
| 41 | + from setuptools import setup, Extension |
| 42 | + from torch.utils import cpp_extension |
| 43 | +
|
| 44 | + setup(name="extension_cpp", |
| 45 | + ext_modules=[ |
| 46 | + cpp_extension.CppExtension("extension_cpp", ["muladd.cpp"])], |
| 47 | + cmdclass={'build_ext': cpp_extension.BuildExtension}) |
| 48 | +
|
| 49 | +If you need to compile CUDA code (e.g. .cu files), then instead use |
| 50 | +`torch.utils.cpp_extension.CUDAExtension <https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension>`_ |
| 51 | +Please see how https://github.com/pytorch/extension-cpp is set up for more details. |
| 52 | + |
| 53 | +Defining the custom op and adding backend implementations |
| 54 | +--------------------------------------------------------- |
| 55 | +First, let’s write a C++ function that computes mymuladd: |
| 56 | + |
| 57 | +.. code-block:: cpp |
| 58 | + at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) { |
| 59 | + TORCH_CHECK(a.sizes() == b.sizes()); |
| 60 | + TORCH_CHECK(a.dtype() == at::kFloat); |
| 61 | + TORCH_CHECK(b.dtype() == at::kFloat); |
| 62 | + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); |
| 63 | + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); |
| 64 | + at::Tensor a_contig = a.contiguous(); |
| 65 | + at::Tensor b_contig = b.contiguous(); |
| 66 | + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); |
| 67 | + const float* a_ptr = a_contig.data_ptr<float>(); |
| 68 | + const float* b_ptr = b_contig.data_ptr<float>(); |
| 69 | + float* result_ptr = result.data_ptr<float>(); |
| 70 | + for (int64_t i = 0; i < result.numel(); i++) { |
| 71 | + result_ptr[i] = a_ptr[i] * b_ptr[i] + c; |
| 72 | + } |
| 73 | + return result; |
| 74 | + } |
| 75 | +
|
| 76 | +In order to use this from PyTorch’s Python frontend, we need to register it |
| 77 | +as a PyTorch operator using the TORCH_LIBRARY API. This will automatically |
| 78 | +bind the operator to Python. |
| 79 | + |
| 80 | +Operator registration is a two step-process: |
| 81 | + |
| 82 | +- we need to define the operator (so that PyTorch knows about it) |
| 83 | +- we need to register various backend implementations (e.g. CPU/CUDA) to the operator |
| 84 | + |
| 85 | +How to define an operator |
| 86 | +^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 87 | +To define an operator: |
| 88 | + |
| 89 | +- select a namespace for an operator. We recommend the namespace be the name of your top-level |
| 90 | +project; we’ll use "extension_cpp" in our tutorial. |
| 91 | +- provide a schema string that specifies the input/output types of the operator and if an |
| 92 | +input Tensors will be mutated. We support more types in addition to Tensor and float; |
| 93 | +please see `The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_ |
| 94 | +for more details. |
| 95 | + |
| 96 | +If you are authoring an operator that can mutate its input Tensors, please see here |
| 97 | +(:ref:`mutable-ops`) for how to specify that. |
| 98 | + |
| 99 | +.. code-block:: cpp |
| 100 | + TORCH_LIBRARY(extension_cpp, m) { |
| 101 | + // Note that "float" in the schema corresponds to the C++ double type |
| 102 | + // and the Python float type. |
| 103 | + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); |
| 104 | + } |
| 105 | +
|
| 106 | +This makes the operator available from Python via ``torch.ops.extension_cpp.mymuladd``. |
| 107 | + |
| 108 | +How to register backend implementations for an operator |
| 109 | +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 110 | +Use TORCH_LIBRARY_IMPL to register a backend implementation for the operator. |
| 111 | + |
| 112 | +.. code-block:: cpp |
| 113 | + TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { |
| 114 | + m.impl("mymuladd", &mymuladd_cpu); |
| 115 | + } |
| 116 | +
|
| 117 | +If we also have a CUDA implementation myaddmul_cuda, we can register it in a separate TORCH_LIBRARY_IMPL block: |
| 118 | + |
| 119 | +.. code-block:: cpp |
| 120 | + __global__ void muladd_kernel(int numel, const float* a, const float* b, float c, float* result) { |
| 121 | + int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 122 | + if (idx < numel) result[idx] = a[idx] * b[idx] + c; |
| 123 | + } |
| 124 | + |
| 125 | + at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) { |
| 126 | + TORCH_CHECK(a.sizes() == b.sizes()); |
| 127 | + TORCH_CHECK(a.dtype() == at::kFloat); |
| 128 | + TORCH_CHECK(b.dtype() == at::kFloat); |
| 129 | + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA); |
| 130 | + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA); |
| 131 | + at::Tensor a_contig = a.contiguous(); |
| 132 | + at::Tensor b_contig = b.contiguous(); |
| 133 | + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); |
| 134 | + const float* a_ptr = a_contig.data_ptr<float>(); |
| 135 | + const float* b_ptr = b_contig.data_ptr<float>(); |
| 136 | + float* result_ptr = result.data_ptr<float>(); |
| 137 | + |
| 138 | + int numel = a_contig.numel(); |
| 139 | + muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr); |
| 140 | + return result; |
| 141 | + } |
| 142 | + |
| 143 | + TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) { |
| 144 | + m.impl("mymuladd", &mymuladd_cuda); |
| 145 | + } |
| 146 | +
|
| 147 | +How to add torch.compile support for an operator |
| 148 | +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 149 | + |
| 150 | +To add torch.compile support for an operator, we must add a FakeTensor kernel (also |
| 151 | +known as a “meta kernel” or “abstract impl”). FakeTensors are Tensors that have |
| 152 | +metadata (i.e. shape, dtype, device) but no data: the FakeTensor kernel for an |
| 153 | +operator specifies how to compute the metadata of output tensors given the metadata of input tensors. |
| 154 | + |
| 155 | +We recommend that this be done from Python via the `torch.library.register_fake` API, |
| 156 | +though it is possible to do this from C++ as well (see |
| 157 | +`The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_ |
| 158 | +for more details). |
| 159 | + |
| 160 | +.. code-block:: python |
| 161 | + @torch.library.register_fake("extension_cpp::mymuladd") |
| 162 | + def _(a, b, c): |
| 163 | + torch._check(a.shape == b.shape) |
| 164 | + torch._check(a.dtype == torch.float) |
| 165 | + torch._check(b.dtype == torch.float) |
| 166 | + torch._check(a.device == b.device) |
| 167 | + return torch.empty_like(a) |
| 168 | + |
| 169 | +How to add training (autograd) support for an operator |
| 170 | +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 171 | +Use torch.library.register_autograd to add training support for an operator. Prefer |
| 172 | +this over directly using Python torch.autograd.Function or C++ torch::autograd::Function; |
| 173 | +one must use those in a very specific way to avoid silent incorrectness (see |
| 174 | +`The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_ |
| 175 | +for more details). |
| 176 | + |
| 177 | +.. code-block:: python |
| 178 | + def _backward(ctx, grad): |
| 179 | + a, b = ctx.saved_tensors |
| 180 | + grad_a, grad_b = None, None |
| 181 | + if ctx.needs_input_grad[0]: |
| 182 | + grad_a = grad * b |
| 183 | + if ctx.needs_input_grad[1]: |
| 184 | + grad_b = grad * a |
| 185 | + return grad_a, grad_b, None |
| 186 | + |
| 187 | + def _setup_context(ctx, inputs, output): |
| 188 | + a, b, c = inputs |
| 189 | + saved_a, saved_b = None, None |
| 190 | + if ctx.needs_input_grad[0]: |
| 191 | + saved_b = b |
| 192 | + if ctx.needs_input_grad[1]: |
| 193 | + saved_a = a |
| 194 | + ctx.save_for_backward(saved_a, saved_b) |
| 195 | + |
| 196 | + # This adds training support for the operator. You must provide us |
| 197 | + # the backward formula for the operator and a `setup_context` function |
| 198 | + # to save values to be used in the backward. |
| 199 | + torch.library.register_autograd( |
| 200 | + "extension_cpp::mymuladd", _backward, setup_context=_setup_context) |
| 201 | +
|
| 202 | +Note that the backward must be a composition of PyTorch-understood operators. |
| 203 | +If you wish to use another custom C++ or CUDA kernel in your backwards pass, |
| 204 | +it must be wrapped into a custom op. |
| 205 | + |
| 206 | +So if we had our own custom mymul kernel, we would need to wrap it into a |
| 207 | +custom operator and then call that from the backward: |
| 208 | + |
| 209 | +.. code-block:: cpp |
| 210 | + // New! a mymul_cpu kernel |
| 211 | + at::Tensor mymul_cpu(const at::Tensor& a, const at::Tensor& b) { |
| 212 | + TORCH_CHECK(a.sizes() == b.sizes()); |
| 213 | + TORCH_CHECK(a.dtype() == at::kFloat); |
| 214 | + TORCH_CHECK(b.dtype() == at::kFloat); |
| 215 | + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); |
| 216 | + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); |
| 217 | + at::Tensor a_contig = a.contiguous(); |
| 218 | + at::Tensor b_contig = b.contiguous(); |
| 219 | + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); |
| 220 | + const float* a_ptr = a_contig.data_ptr<float>(); |
| 221 | + const float* b_ptr = b_contig.data_ptr<float>(); |
| 222 | + float* result_ptr = result.data_ptr<float>(); |
| 223 | + for (int64_t i = 0; i < result.numel(); i++) { |
| 224 | + result_ptr[i] = a_ptr[i] * b_ptr[i]; |
| 225 | + } |
| 226 | + return result; |
| 227 | + } |
| 228 | + |
| 229 | + TORCH_LIBRARY(extension_cpp, m) { |
| 230 | + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); |
| 231 | + // New! defining the mymul operator |
| 232 | + m.def("mymul(Tensor a, Tensor b) -> Tensor"); |
| 233 | + } |
| 234 | + |
| 235 | + |
| 236 | + TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { |
| 237 | + m.impl("mymuladd", &mymuladd_cpu); |
| 238 | + // New! registering the cpu kernel for the mymul operator |
| 239 | + m.impl("mymul", &mymul_cpu); |
| 240 | + } |
| 241 | +
|
| 242 | +.. code-block:: python |
| 243 | +
|
| 244 | + def _backward(ctx, grad): |
| 245 | + a, b = ctx.saved_tensors |
| 246 | + grad_a, grad_b = None, None |
| 247 | + if ctx.needs_input_grad[0]: |
| 248 | + grad_a = torch.ops.extension_cpp.mymul.default(grad, b) |
| 249 | + if ctx.needs_input_grad[1]: |
| 250 | + grad_b = torch.ops.extension_cpp.mymul.default(grad, a) |
| 251 | + return grad_a, grad_b, None |
| 252 | + |
| 253 | + |
| 254 | + def _setup_context(ctx, inputs, output): |
| 255 | + a, b, c = inputs |
| 256 | + saved_a, saved_b = None, None |
| 257 | + if ctx.needs_input_grad[0]: |
| 258 | + saved_b = b |
| 259 | + if ctx.needs_input_grad[1]: |
| 260 | + saved_a = a |
| 261 | + ctx.save_for_backward(saved_a, saved_b) |
| 262 | + |
| 263 | + |
| 264 | + # This adds training support for the operator. You must provide us |
| 265 | + # the backward formula for the operator and a `setup_context` function |
| 266 | + # to save values to be used in the backward. |
| 267 | + torch.library.register_autograd( |
| 268 | + "extension_cpp::mymuladd", _backward, setup_context=_setup_context) |
| 269 | +
|
| 270 | +How to test an operator |
| 271 | +----------------------- |
| 272 | +Use torch.library.opcheck to test that the custom op was registered correctly. |
| 273 | +This does not test that the gradients are mathematically correct; please write |
| 274 | +separate tests for that (either manual ones or torch.autograd.gradcheck). |
| 275 | + |
| 276 | +.. code-block:: python |
| 277 | + def sample_inputs(device, *, requires_grad=False): |
| 278 | + def make_tensor(*size): |
| 279 | + return torch.randn(size, device=device, requires_grad=requires_grad) |
| 280 | + |
| 281 | + def make_nondiff_tensor(*size): |
| 282 | + return torch.randn(size, device=device, requires_grad=False) |
| 283 | + |
| 284 | + return [ |
| 285 | + [make_tensor(3), make_tensor(3), 1], |
| 286 | + [make_tensor(20), make_tensor(20), 3.14], |
| 287 | + [make_tensor(20), make_nondiff_tensor(20), -123], |
| 288 | + [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3], |
| 289 | + ] |
| 290 | + |
| 291 | + def reference_muladd(a, b, c): |
| 292 | + return a * b + c |
| 293 | + |
| 294 | + samples = sample_inputs(device, requires_grad=True) |
| 295 | + samples.extend(sample_inputs(device, requires_grad=False)) |
| 296 | + for args in samples: |
| 297 | + # Correctness test |
| 298 | + result = torch.ops.extension_cpp.mymuladd(*args) |
| 299 | + expected = reference_muladd(*args) |
| 300 | + torch.testing.assert_close(result, expected) |
| 301 | + |
| 302 | + # Use opcheck to check for incorrect usage of operator registration APIs |
| 303 | + torch.library.opcheck(torch.ops.extension_cpp.mymuladd.default, args) |
| 304 | +
|
| 305 | +.. _mutable-ops: |
| 306 | + |
| 307 | +How to create mutable operators |
| 308 | +------------------------------- |
| 309 | +You may wish to author a custom operator that mutates its inputs. Use ``Tensor(a!)`` |
| 310 | +to specify each mutable Tensor in the schema; otherwise, there will be undefined |
| 311 | +behavior. If there are multiple mutated Tensors, use different names (i.e. ``Tensor(a!)``, |
| 312 | +``Tensor(b!)``, ``Tensor(c!)``) for each mutable Tensor. |
| 313 | + |
| 314 | +Let's author a ``myadd_out(a, b, out)`` operator, which writes the contents of ``a+b`` into ``out``. |
| 315 | + |
| 316 | +.. code-block:: cpp |
| 317 | + // An example of an operator that mutates one of its inputs. |
| 318 | + void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) { |
| 319 | + TORCH_CHECK(a.sizes() == b.sizes()); |
| 320 | + TORCH_CHECK(b.sizes() == out.sizes()); |
| 321 | + TORCH_CHECK(a.dtype() == at::kFloat); |
| 322 | + TORCH_CHECK(b.dtype() == at::kFloat); |
| 323 | + TORCH_CHECK(out.dtype() == at::kFloat); |
| 324 | + TORCH_CHECK(out.is_contiguous()); |
| 325 | + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); |
| 326 | + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); |
| 327 | + TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CPU); |
| 328 | + at::Tensor a_contig = a.contiguous(); |
| 329 | + at::Tensor b_contig = b.contiguous(); |
| 330 | + const float* a_ptr = a_contig.data_ptr<float>(); |
| 331 | + const float* b_ptr = b_contig.data_ptr<float>(); |
| 332 | + float* result_ptr = out.data_ptr<float>(); |
| 333 | + for (int64_t i = 0; i < out.numel(); i++) { |
| 334 | + result_ptr[i] = a_ptr[i] + b_ptr[i]; |
| 335 | + } |
| 336 | + } |
| 337 | +
|
| 338 | +When defining the operator, we must specify that it mutates the out Tensor in the schema: |
| 339 | + |
| 340 | +.. code-block:: cpp |
| 341 | + TORCH_LIBRARY(extension_cpp, m) { |
| 342 | + m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); |
| 343 | + m.def("mymul(Tensor a, Tensor b) -> Tensor"); |
| 344 | + // New! |
| 345 | + m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()"); |
| 346 | + } |
| 347 | +
|
| 348 | + TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { |
| 349 | + m.impl("mymuladd", &mymuladd_cpu); |
| 350 | + m.impl("mymul", &mymul_cpu); |
| 351 | + // New! |
| 352 | + m.impl("myadd_out", &myadd_out_cpu); |
| 353 | + } |
| 354 | +
|
| 355 | +Please do not return any mutated Tensors as outputs of the operator; this will |
| 356 | +run you into problems later down the line. |
| 357 | + |
| 358 | +Conclusion |
| 359 | +---------- |
| 360 | +In this tutorial, we went over the recommended approach to integrating Custom C++ |
| 361 | +and CUDA operators with PyTorch. The TORCH_LIBRARY/torch.library APIs are fairly |
| 362 | +low-level; more detail about how to use them can be found over at |
| 363 | +`The Custom Operators Manual <https://pytorch.org/docs/main/notes/custom_operators.html>`_ |
| 364 | + |
| 365 | + |
0 commit comments