diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py
index cd23c0e05d8..596ab1700c9 100644
--- a/.jenkins/validate_tutorials_built.py
+++ b/.jenkins/validate_tutorials_built.py
@@ -10,6 +10,7 @@ NOT_RUN = [
     "beginner_source/basics/intro",  # no code
+    "beginner_source/onnx/intro_onnx",
     "beginner_source/translation_transformer",
     "beginner_source/profiler",
     "beginner_source/saving_loading_models",
diff --git a/_static/img/onnx/custom_addandround_function.png b/_static/img/onnx/custom_addandround_function.png
new file mode 100644
index 00000000000..a0c7000161e
Binary files /dev/null and b/_static/img/onnx/custom_addandround_function.png differ
diff --git a/_static/img/onnx/custom_addandround_model.png b/_static/img/onnx/custom_addandround_model.png
new file mode 100644
index 00000000000..793d8cfbb5d
Binary files /dev/null and b/_static/img/onnx/custom_addandround_model.png differ
diff --git a/_static/img/onnx/custom_aten_add_function.png b/_static/img/onnx/custom_aten_add_function.png
new file mode 100644
index 00000000000..d9f927ce707
Binary files /dev/null and b/_static/img/onnx/custom_aten_add_function.png differ
diff --git a/_static/img/onnx/custom_aten_add_model.png b/_static/img/onnx/custom_aten_add_model.png
new file mode 100644
index 00000000000..e5ef1c71742
Binary files /dev/null and b/_static/img/onnx/custom_aten_add_model.png differ
diff --git a/_static/img/onnx/custom_aten_gelu_function.png b/_static/img/onnx/custom_aten_gelu_function.png
new file mode 100644
index 00000000000..5cb573e7dcb
Binary files /dev/null and b/_static/img/onnx/custom_aten_gelu_function.png differ
diff --git a/_static/img/onnx/custom_aten_gelu_model.png b/_static/img/onnx/custom_aten_gelu_model.png
new file mode 100644
index 00000000000..6bc46337b48
Binary files /dev/null and b/_static/img/onnx/custom_aten_gelu_model.png differ
diff --git a/_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png b/_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png
new file mode 100755
index 00000000000..0c29c168798
Binary files /dev/null and b/_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png differ
diff --git a/_static/img/onnx/netron_web_ui.png b/_static/img/onnx/netron_web_ui.png
new file mode 100755
index 00000000000..f88936eb824
Binary files /dev/null and b/_static/img/onnx/netron_web_ui.png differ
diff --git a/_static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png b/_static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
new file mode 100755
index 00000000000..00156df042e
Binary files /dev/null and b/_static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png differ
diff --git a/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png b/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png
index 426a14d98f5..00156df042e 100644
Binary files a/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png and b/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png differ
diff --git a/_templates/layout.html b/_templates/layout.html
index 660c6870217..242e347d092 100644
--- a/_templates/layout.html
+++ b/_templates/layout.html
@@ -107,7 +107,7 @@
diff --git a/advanced_source/super_resolution_with_onnxruntime.py b/advanced_source/super_resolution_with_onnxruntime.py
index 835a79bd3a0..466c124fe67 100644
--- a/advanced_source/super_resolution_with_onnxruntime.py
+++ b/advanced_source/super_resolution_with_onnxruntime.py
@@ -1,10 +1,17 @@
 """
 (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
-========================================================================
+===================================================================================
+
+.. note::
+    As of PyTorch 2.1, there are two versions of ONNX Exporter.
+
+    * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter, based on the TorchDynamo technology released with PyTorch 2.0.
+    * ``torch.onnx.export`` is based on the TorchScript backend and has been available since PyTorch 1.2.0.
 
 In this tutorial, we describe how to convert a model defined
-in PyTorch into the ONNX format and then run it with ONNX Runtime.
+in PyTorch into the ONNX format using the TorchScript ``torch.onnx.export`` ONNX exporter.
+The exported model will be executed with ONNX Runtime.
 ONNX Runtime is a performance-focused engine for ONNX models,
 which inferences efficiently across multiple platforms and hardware
 (Windows, Linux, and Mac and on both CPUs and GPUs).
@@ -15,13 +22,17 @@
 For this tutorial, you will need to install `ONNX `__
 and `ONNX Runtime `__.
 You can get binary builds of ONNX and ONNX Runtime with
-``pip install onnx onnxruntime``.
+
+.. code-block:: bash
+
+    pip install onnxruntime
+
 ONNX Runtime recommends using the latest stable runtime for PyTorch.
 
 """
 
 # Some standard imports
-import io
 import numpy as np
 
 from torch import nn
@@ -185,7 +196,7 @@ def _initialize_weights(self):
 
 import onnxruntime
 
-ort_session = onnxruntime.InferenceSession("super_resolution.onnx")
+ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])
 
 def to_numpy(tensor):
     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
diff --git a/beginner_source/onnx/README.txt b/beginner_source/onnx/README.txt
new file mode 100644
index 00000000000..5c9249c640e
--- /dev/null
+++ b/beginner_source/onnx/README.txt
@@ -0,0 +1,14 @@
+ONNX
+----
+
+1. intro_onnx.py
+    Introduction to ONNX
+    https://pytorch.org/tutorials/beginner/onnx/intro_onnx.html
+
+2. export_simple_model_to_onnx_tutorial.py
+    Export a PyTorch model to ONNX
+    https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html
+
+3. onnx_registry_tutorial.py
+    Introduction to ONNX Registry
+    https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html
\ No newline at end of file
diff --git a/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py b/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
new file mode 100644
index 00000000000..0c76847d411
--- /dev/null
+++ b/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
@@ -0,0 +1,212 @@
+# -*- coding: utf-8 -*-
+"""
+`Introduction to ONNX `_ ||
+**Export a PyTorch model to ONNX** ||
+`Introduction to ONNX Registry `_
+
+Export a PyTorch model to ONNX
+==============================
+
+**Author**: `Thiago Crepaldi `_
+
+.. note::
+    As of PyTorch 2.1, there are two versions of ONNX Exporter.
+
+    * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter, based on the TorchDynamo technology released with PyTorch 2.0.
+    * ``torch.onnx.export`` is based on the TorchScript backend and has been available since PyTorch 1.2.0.
+
+"""
+
+###############################################################################
+# In the `60 Minute Blitz `_,
+# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
+# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
+# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
+#
+# While PyTorch is great for iterating on the development of models, the model can be deployed to production
+# using different formats, including `ONNX `_ (Open Neural Network Exchange)!
+#
+# ONNX is a flexible open standard format for representing machine learning models. These standardized
+# representations allow the models to be executed across a gamut of hardware platforms and runtime environments,
+# from large-scale cloud-based supercomputers to resource-constrained edge devices, such as your web browser and phone.
+#
+# In this tutorial, we’ll learn how to:
+#
+# 1. Install the required dependencies.
+# 2. Author a simple image classifier model.
+# 3. Export the model to ONNX format.
+# 4. Save the ONNX model in a file.
+# 5. Visualize the ONNX model graph using `Netron `_.
+# 6. Execute the ONNX model with `ONNX Runtime`.
+# 7. Compare the PyTorch results with the ones from ONNX Runtime.
+#
+# 1. Install the required dependencies
+# ------------------------------------
+# Because the ONNX exporter uses ``onnx`` and ``onnxscript`` to translate PyTorch operators into ONNX operators,
+# we will need to install them.
+#
+# .. code-block:: bash
+#
+#    pip install onnx
+#    pip install onnxscript
+#
+# 2. Author a simple image classifier model
+# -----------------------------------------
+#
+# Once your environment is set up, let’s start modeling our image classifier with PyTorch,
+# exactly like we did in the `60 Minute Blitz `_.
+#
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MyModel(nn.Module):
+
+    def __init__(self):
+        super(MyModel, self).__init__()
+        self.conv1 = nn.Conv2d(1, 6, 5)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
+        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
+        x = torch.flatten(x, 1)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+######################################################################
+# 3. Export the model to ONNX format
+# ----------------------------------
+#
+# Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
+# Next, we can export the model to ONNX format.
+
+torch_model = MyModel()
+torch_input = torch.randn(1, 1, 32, 32)
+export_output = torch.onnx.dynamo_export(torch_model, torch_input)
+
+######################################################################
+# As we can see, we didn't need any code changes to the model.
+# The resulting ONNX model is stored within ``torch.onnx.ExportOutput`` as a binary protobuf file.
+#
+# 4. Save the ONNX model in a file
+# --------------------------------
+#
+# Although having the exported model loaded in memory is useful in many applications,
+# we can save it to disk with the following code:
+
+export_output.save("my_image_classifier.onnx")
+
+######################################################################
+# You can load the ONNX file back into memory and check if it is well formed with the following code:
+
+import onnx
+onnx_model = onnx.load("my_image_classifier.onnx")
+onnx.checker.check_model(onnx_model)
+
+######################################################################
+# 5. Visualize the ONNX model graph using Netron
+# ----------------------------------------------
+#
+# Now that we have our model saved in a file, we can visualize it with `Netron `_.
+# Netron can either be installed on macOS, Linux, or Windows computers, or run directly from the browser.
+# Let's try the web version by opening the following link: https://netron.app/.
+#
+# .. image:: ../../_static/img/onnx/netron_web_ui.png
+#    :width: 70%
+#    :align: center
+#
+#
+# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser, or select it after
+# clicking the **Open model** button.
+#
+# .. image:: ../../_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png
+#    :width: 50%
+#
+#
+# And that is it! We have successfully exported our PyTorch model to ONNX format and visualized it with Netron.
+#
+# 6. Execute the ONNX model with ONNX Runtime
+# -------------------------------------------
+#
+# The last step is executing the ONNX model with `ONNX Runtime`, but before we do that, let's install ONNX Runtime.
+#
+# .. code-block:: bash
+#
+#    pip install onnxruntime
+#
+# The ONNX standard does not support all the data structures and types that PyTorch does,
+# so we need to adapt the PyTorch inputs to the ONNX format before feeding them to ONNX Runtime.
+# In our example, the input happens to be the same, but more complex models
+# can end up with more inputs than the original PyTorch model.
+#
+# ONNX Runtime requires an additional step: converting all PyTorch tensors to NumPy (on CPU)
+# and wrapping them in a dictionary whose keys are the input names (strings) and whose values are the NumPy tensors.
+#
+# Now we can create an *ONNX Runtime Inference Session*, execute the ONNX model with the processed input,
+# and get the output. In this tutorial, ONNX Runtime is executed on CPU, but it could be executed on GPU as well.
+
+import onnxruntime
+
+onnx_input = export_output.adapt_torch_inputs_to_onnx(torch_input)
+print(f"Input length: {len(onnx_input)}")
+print(f"Sample input: {onnx_input}")
+
+ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])
+
+def to_numpy(tensor):
+    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+
+onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+
+######################################################################
+# 7. Compare the PyTorch results with the ones from ONNX Runtime
+# --------------------------------------------------------------
+#
+# The best way to determine whether the exported model is looking good is through numerical evaluation
+# against PyTorch, which is our source of truth.
+#
+# For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
+# Before comparing the results, we need to convert PyTorch's output to match ONNX's format.
+
+torch_outputs = torch_model(torch_input)
+torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
+
+assert len(torch_outputs) == len(onnxruntime_outputs)
+for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
+    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
+
+print("PyTorch and ONNX Runtime output matched!")
+print(f"Output length: {len(onnxruntime_outputs)}")
+print(f"Sample output: {onnxruntime_outputs}")
+
+######################################################################
+# Conclusion
+# ----------
+#
+# That is about it! We have successfully exported our PyTorch model to ONNX format,
+# saved the model to disk, viewed it using Netron, executed it with ONNX Runtime,
+# and finally compared its numerical results with PyTorch's.
+#
+# Further reading
+# ---------------
+#
+# The list below refers to tutorials that range from basic examples to advanced scenarios,
+# not necessarily in the order they are listed.
+# Feel free to jump directly to specific topics of your interest or
+# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
+#
+# .. include:: /beginner_source/onnx/onnx_toc.txt
+#
+# .. toctree::
+#    :hidden:
+#
\ No newline at end of file
diff --git a/beginner_source/onnx/intro_onnx.py b/beginner_source/onnx/intro_onnx.py
new file mode 100644
index 00000000000..d86e917a5e5
--- /dev/null
+++ b/beginner_source/onnx/intro_onnx.py
@@ -0,0 +1,68 @@
+"""
+**Introduction to ONNX** ||
+`Export a PyTorch model to ONNX `_ ||
+`Introduction to ONNX Registry `_
+
+Introduction to ONNX
+====================
+
+Author: `Thiago Crepaldi `_
+
+`Open Neural Network eXchange (ONNX) `_ is an open standard
+format for representing machine learning models. The ``torch.onnx`` module provides APIs to
+capture the computation graph from a native PyTorch :class:`torch.nn.Module` model and convert
+it into an `ONNX graph `_.
+
+The exported model can be consumed by any of the many
+`runtimes that support ONNX `_,
+including Microsoft's `ONNX Runtime `_.
+
+.. note::
+    Currently, there are two flavors of ONNX exporter APIs,
+    but this tutorial will focus on ``torch.onnx.dynamo_export``.
+
+The TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its
+bytecode into an `FX graph `_.
+The resulting FX graph is polished before it is finally translated into an
+`ONNX graph `_.
+
+The main advantage of this approach is that the `FX graph `_ is captured using
+bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.
+
+Dependencies
+------------
+
+PyTorch 2.1.0 or newer is required.
+
+The ONNX exporter depends on extra Python packages:
+
+  - `ONNX `_
+  - `ONNX Script `_
+
+They can be installed through `pip `_:
+
+.. code-block:: bash
+
+    pip install --upgrade onnx onnxscript
+
+.. note::
+    This tutorial leverages `onnxscript `__,
+    a Python library that allows users to create custom ONNX operators in Python.
+    It is prerequisite learning material for this tutorial,
+    so please make sure you have read the onnxscript tutorial before proceeding.
+
+Further reading
+---------------
+
+The list below refers to tutorials that range from basic examples to advanced scenarios,
+not necessarily in the order they are listed.
+Feel free to jump directly to specific topics of your interest or
+sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
+
+.. include:: /beginner_source/onnx/onnx_toc.txt
+
+.. toctree::
+   :hidden:
+
+"""
diff --git a/beginner_source/onnx/onnx_registry_tutorial.py b/beginner_source/onnx/onnx_registry_tutorial.py
new file mode 100644
index 00000000000..f3c6cb94b7a
--- /dev/null
+++ b/beginner_source/onnx/onnx_registry_tutorial.py
@@ -0,0 +1,469 @@
+# -*- coding: utf-8 -*-
+
+"""
+`Introduction to ONNX `_ ||
+`Export a PyTorch model to ONNX `_ ||
+**Introduction to ONNX Registry**
+
+Introduction to ONNX Registry
+=============================
+
+**Author:** Ti-Tai Wang (titaiwang@microsoft.com)
+"""
+
+
+###############################################################################
+# Overview
+# ~~~~~~~~
+#
+# This tutorial is an introduction to the ONNX registry, which
+# empowers users to create their own ONNX registries, enabling
+# them to address unsupported operators in ONNX.
+#
+# In this tutorial, we will cover the following scenarios:
+#
+# * Unsupported ATen operators
+# * Unsupported ATen operators with existing ONNX Runtime support
+# * Unsupported PyTorch operators with no ONNX Runtime support
+#
+
+
+import torch
+print(torch.__version__)
+torch.manual_seed(191009)  # set the seed for reproducibility
+
+import onnxscript  # pip install onnxscript
+print(onnxscript.__version__)
+
+# NOTE: opset18 is the only version of ONNX operators we are
+# using in torch.onnx.dynamo_export for now.
+from onnxscript import opset18
+
+import onnxruntime  # pip install onnxruntime
+print(onnxruntime.__version__)
+
+
+######################################################################
+# Unsupported ATen operators
+# ---------------------------------
+#
+# ATen operators are implemented by PyTorch, and the ONNX exporter team must manually implement the
+# conversion from ATen operators to ONNX operators through `ONNX Script <https://onnxscript.ai/>`__. Although the ONNX exporter
+# team has been making their best efforts to support as many ATen operators as possible, some ATen
+# operators are still not supported. In this section, we will demonstrate how you can implement any
+# unsupported ATen operator and contribute it back to the project through the PyTorch GitHub repo.
+#
+# If the model cannot be exported to ONNX because an ATen operator is not supported,
+# for instance ``aten::add.Tensor``, the exporter raises an error message such as:
+# ``RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.add.Tensor']}.``
+#
+# To support unsupported ATen operators, we need two things:
+#
+# * The unsupported ATen operator namespace, operator name, and the
+#   corresponding overload (for example ``<namespace>::<op_name>.<overload>`` - ``aten::add.Tensor``),
+#   which can be found in the error message.
+# * The implementation of the operator in `ONNX Script `__.
+#
+
+
+# NOTE: `is_registered_op` is a method in the ONNX registry that checks
+# whether the operator is supported by ONNX. If the operator is not
+# supported, it will return False. Otherwise, it will return True.
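+# Before we query the registry, here is a minimal sketch (not executed in this
+# tutorial) of how an unsupported operator surfaces at export time. The
+# ``UnsupportedModel`` name is hypothetical, and we catch a broad ``Exception``
+# because the exact diagnostic type may vary across PyTorch releases:
+#
+# .. code-block:: python
+#
+#    class UnsupportedModel(torch.nn.Module):
+#        def forward(self, input_x, input_y):
+#            # pretend aten::add.Tensor had no registered ONNX translation
+#            return torch.ops.aten.add(input_x, input_y)
+#
+#    try:
+#        torch.onnx.dynamo_export(
+#            UnsupportedModel(), torch.randn(3, 4), torch.randn(3, 4))
+#    except Exception as error:
+#        # the message lists the missing FX nodes, for example:
+#        # Unsupported FX nodes: {'call_function': ['aten.add.Tensor']}
+#        print(error)
+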
+onnx_registry = torch.onnx.OnnxRegistry()
+# aten::add.default and aten::add.Tensor are supported by ONNX
+print(f"aten::add.default is supported by ONNX registry: \
+    {onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='default')}")
+# aten::add.Tensor is the one invoked by torch.ops.aten.add
+print(f"aten::add.Tensor is supported by ONNX registry: \
+    {onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='Tensor')}")
+
+
+######################################################################
+# In this example, we will assume that ``aten::add.Tensor`` is not supported by the ONNX registry,
+# and we will demonstrate how to support it. The ONNX registry allows user overrides for operator
+# registration, so we will override the registration of ``aten::add.Tensor`` with our
+# implementation and verify it. Note that for a genuinely unsupported operator,
+# :meth:`onnx_registry.is_registered_op` would return False.
+#
+
+
+class Model(torch.nn.Module):
+    def forward(self, input_x, input_y):
+        # specifically call out aten::add
+        return torch.ops.aten.add(input_x, input_y)
+
+input_add_x = torch.randn(3, 4)
+input_add_y = torch.randn(3, 4)
+aten_add_model = Model()
+
+
+# Let's create an onnxscript function to support aten::add.Tensor.
+# The function can be named anything, and the name shows up later in the Netron graph.
+custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1)
+
+# NOTE: The function signature must match the signature of the unsupported ATen operator.
+# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+# NOTE: All attributes must be annotated with type hints.
+@onnxscript.script(custom_aten)
+def custom_aten_add(input_x, input_y, alpha: float = 1.0):
+    alpha = opset18.CastLike(alpha, input_y)
+    input_y = opset18.Mul(input_y, alpha)
+    return opset18.Add(input_x, input_y)
+
+
+# Now we have both things we need to support unsupported ATen operators.
+# Let's register the custom_aten_add function to the ONNX registry, and
+# export the model to ONNX again.
+onnx_registry.register_op(
+    namespace="aten", op_name="add", overload="Tensor", function=custom_aten_add
+    )
+print(f"aten::add.Tensor is supported by ONNX registry: \
+    {onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='Tensor')}"
+    )
+export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
+export_output = torch.onnx.dynamo_export(
+    aten_add_model, input_add_x, input_add_y, export_options=export_options
+    )
+
+######################################################################
+# Make sure the model uses ``custom_aten_add`` instead of ``aten::add.Tensor``.
+# The graph has one graph node for ``custom_aten_add``, and inside
+# ``custom_aten_add`` there are four function nodes, one for each
+# operator, and one for the constant attribute.
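+
+# As a quick sanity check before the assertions below, we can print the graph
+# structure by hand. This sketch relies only on standard ``onnx`` protobuf
+# fields (``graph.node``, ``functions``), not on any exporter-specific API.
+for node in export_output.model_proto.graph.node:
+    print(f"graph node: domain={node.domain!r}, op_type={node.op_type!r}")
+for function_proto in export_output.model_proto.functions:
+    print(f"function: {function_proto.name!r} with {len(function_proto.node)} nodes")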
+#
+
+# graph node domain is the custom domain we registered
+assert export_output.model_proto.graph.node[0].domain == "custom.aten"
+assert len(export_output.model_proto.graph.node) == 1
+# graph node name is the function name
+assert export_output.model_proto.graph.node[0].op_type == "custom_aten_add"
+# function node domain is empty because we use standard ONNX operators
+assert export_output.model_proto.functions[0].node[3].domain == ""
+# function node name is the standard ONNX operator name
+assert export_output.model_proto.functions[0].node[3].op_type == "Add"
+
+
+######################################################################
+# The ``custom_aten_add_model`` ONNX graph in Netron:
+#
+# .. image:: /_static/img/onnx/custom_aten_add_model.png
+#    :width: 70%
+#    :align: center
+#
+# Inside the ``custom_aten_add`` function, we can see the three ONNX nodes we
+# used in the function (``CastLike``, ``Add``, and ``Mul``), and one constant attribute:
+#
+# .. image:: /_static/img/onnx/custom_aten_add_function.png
+#    :width: 70%
+#    :align: center
+#
+# After checking the ONNX graph, we can use ONNX Runtime to run the model.
+
+
+# Use ONNX Runtime to run the model, and compare the results with PyTorch
+export_output.save("./custom_add_model.onnx")
+ort_session = onnxruntime.InferenceSession(
+    "./custom_add_model.onnx", providers=['CPUExecutionProvider']
+    )
+
+def to_numpy(tensor):
+    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+onnx_input = export_output.adapt_torch_inputs_to_onnx(input_add_x, input_add_y)
+onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+
+torch_outputs = aten_add_model(input_add_x, input_add_y)
+torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
+
+assert len(torch_outputs) == len(onnxruntime_outputs)
+for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
+    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
+
+
+######################################################################
+# Unsupported ATen operators with existing ONNX Runtime support
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# In this case, the unsupported ATen operator is supported by ONNX Runtime but not
+# by the ONNX spec. This occurs because ONNX Runtime allows users to implement custom
+# operators that the runtime can execute. When the need arises, ONNX Runtime
+# will contribute these custom operators to the ONNX spec. Therefore, in the ONNX registry,
+# we only need to register the operator with the recognized namespace and operator name.
+#
+# In the following example, we would like to use the Gelu from ONNX Runtime,
+# which is not the same as the Gelu from the ONNX spec. Thus, we register the Gelu with
+# the namespace "com.microsoft" and operator name "Gelu".
+#
+
+
+class CustomGelu(torch.nn.Module):
+    def forward(self, input_x):
+        return torch.ops.aten.gelu(input_x)
+
+# com.microsoft is an official ONNX Runtime namespace
+custom_ort = onnxscript.values.Opset(domain="com.microsoft", version=1)
+
+# NOTE: The function signature must match the signature of the unsupported ATen operator.
+# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+# NOTE: All attributes must be annotated with type hints.
+@onnxscript.script(custom_ort)
+def custom_aten_gelu(input_x, approximate: str = "none"):
+    # We know com.microsoft::Gelu is supported by ONNX Runtime
+    # It's only not supported by the ONNX spec
+    return custom_ort.Gelu(input_x)
+
+
+onnx_registry = torch.onnx.OnnxRegistry()
+onnx_registry.register_op(
+    namespace="aten", op_name="gelu", overload="default", function=custom_aten_gelu)
+export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
+
+aten_gelu_model = CustomGelu()
+input_gelu_x = torch.randn(3, 3)
+
+export_output = torch.onnx.dynamo_export(
+    aten_gelu_model, input_gelu_x, export_options=export_options
+    )
+
+
+######################################################################
+# Make sure the model uses ``custom_aten_gelu`` instead of
+# ``aten::gelu``. The graph has one graph node for
+# ``custom_aten_gelu``, and inside ``custom_aten_gelu`` there is a function
+# node for ``Gelu`` with the namespace "com.microsoft".
+#
+
+# graph node domain is the custom domain we registered
+assert export_output.model_proto.graph.node[0].domain == "com.microsoft"
+# graph node name is the function name
+assert export_output.model_proto.graph.node[0].op_type == "custom_aten_gelu"
+# function node domain is the custom domain we registered
+assert export_output.model_proto.functions[0].node[0].domain == "com.microsoft"
+# function node name is the node name used in the function
+assert export_output.model_proto.functions[0].node[0].op_type == "Gelu"
+
+
+######################################################################
+# The following diagram shows the ``custom_aten_gelu_model`` ONNX graph in Netron:
+#
+# .. image:: /_static/img/onnx/custom_aten_gelu_model.png
+#    :width: 70%
+#    :align: center
+#
+# Inside the ``custom_aten_gelu`` function, we can see the ``Gelu`` node from the
+# "com.microsoft" domain that we used in the function:
+#
+# .. image:: /_static/img/onnx/custom_aten_gelu_function.png
+#
+# After checking the ONNX graph, we can use ONNX Runtime to run the model.
+
+
+export_output.save("./custom_gelu_model.onnx")
+ort_session = onnxruntime.InferenceSession(
+    "./custom_gelu_model.onnx", providers=['CPUExecutionProvider']
+    )
+
+def to_numpy(tensor):
+    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+onnx_input = export_output.adapt_torch_inputs_to_onnx(input_gelu_x)
+onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+
+torch_outputs = aten_gelu_model(input_gelu_x)
+torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
+
+assert len(torch_outputs) == len(onnxruntime_outputs)
+for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
+    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
+
+
+######################################################################
+# Unsupported PyTorch operators with no ONNX Runtime support
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# In this case, the operator is not supported by any framework, and we
+# would like to use it in the ONNX graph. Therefore, we need to implement
+# the operator in three places:
+#
+# 1. PyTorch FX graph
+# 2. ONNX Registry
+# 3. ONNX Runtime
+#
+# In the following example, we would like to use a custom operator
+# that takes one tensor input and returns one tensor output. The operator adds
+# the input to itself, and returns the rounded result.
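+#
+# In plain eager PyTorch, the behavior we are after is simply the following
+# reference (shown only for intuition; ``addandround_reference`` is a
+# hypothetical name, not part of any API):
+#
+# .. code-block:: python
+#
+#    def addandround_reference(tensor_x: torch.Tensor) -> torch.Tensor:
+#        return torch.round(tensor_x + tensor_x)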
+#
+# **Custom Ops Registration in PyTorch FX Graph (Beta)**
+#
+# Firstly, we need to implement the operator in the PyTorch FX graph.
+# This can be done by using ``torch._custom_op``.
+#
+
+# NOTE: This is a beta feature in PyTorch, and is subject to change.
+from torch._custom_op import impl as custom_op
+
+@custom_op.custom_op("mylibrary::addandround_op")
+def addandround_op(tensor_x: torch.Tensor) -> torch.Tensor:
+    ...
+
+@addandround_op.impl_abstract()
+def addandround_op_impl_abstract(tensor_x):
+    return torch.empty_like(tensor_x)
+
+@addandround_op.impl("cpu")
+def addandround_op_impl(tensor_x):
+    # add x to itself, and round the result
+    return torch.round(tensor_x + tensor_x)
+
+torch._dynamo.allow_in_graph(addandround_op)
+
+class CustomFoo(torch.nn.Module):
+    def forward(self, tensor_x):
+        return addandround_op(tensor_x)
+
+input_addandround_x = torch.randn(3)
+custom_addandround_model = CustomFoo()
+
+
+######################################################################
+# **Custom Ops Registration in ONNX Registry**
+#
+# For steps 2 and 3, we need to implement the operator in the ONNX registry.
+# In this example, we will implement the operator in the ONNX registry
+# with the namespace "test.customop" and the operator names "CustomOpOne"
+# and "CustomOpTwo". These two ops are registered and built in
+# `cpu_ops.cc `__.
+#
+
+
+custom_opset = onnxscript.values.Opset(domain="test.customop", version=1)
+
+# NOTE: The function signature must match the signature of the unsupported ATen operator.
+# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+# NOTE: All attributes must be annotated with type hints.
+@onnxscript.script(custom_opset)
+def custom_addandround(input_x):
+    # The same as opset18.Add(x, x)
+    add_x = custom_opset.CustomOpOne(input_x, input_x)
+    # The same as opset18.Round(x)
+    round_x = custom_opset.CustomOpTwo(add_x)
+    # Cast to FLOAT to match the ONNX type
+    return opset18.Cast(round_x, to=1)
+
+
+onnx_registry = torch.onnx.OnnxRegistry()
+onnx_registry.register_op(
+    namespace="mylibrary", op_name="addandround_op", overload="default", function=custom_addandround
+    )
+
+export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
+export_output = torch.onnx.dynamo_export(
+    custom_addandround_model, input_addandround_x, export_options=export_options
+    )
+export_output.save("./custom_addandround_model.onnx")
+
+
+######################################################################
+# The exported model proto is accessible through ``export_output.model_proto``.
+# The graph has one graph node for ``custom_addandround``, and inside ``custom_addandround``,
+# there are two function nodes, one for each operator.
+#
+
+assert export_output.model_proto.graph.node[0].domain == "test.customop"
+assert export_output.model_proto.graph.node[0].op_type == "custom_addandround"
+assert export_output.model_proto.functions[0].node[0].domain == "test.customop"
+assert export_output.model_proto.functions[0].node[0].op_type == "CustomOpOne"
+assert export_output.model_proto.functions[0].node[1].domain == "test.customop"
+assert export_output.model_proto.functions[0].node[1].op_type == "CustomOpTwo"
+
+
+######################################################################
+# The ``custom_addandround_model`` ONNX graph in Netron:
+#
+# .. image:: /_static/img/onnx/custom_addandround_model.png
+#    :width: 70%
+#    :align: center
+#
+# Inside the ``custom_addandround`` function, we can see the two custom operator nodes we
+# used in the function (``CustomOpOne`` and ``CustomOpTwo``), and they are from the
+# "test.customop" domain:
+#
+# .. image:: /_static/img/onnx/custom_addandround_function.png
+#
+
+######################################################################
+# **Custom Ops Registration in ONNX Runtime**
+#
+# To link your custom op library to ONNX Runtime, you need to
+# compile your C++ code into a shared library and link it to ONNX Runtime.
+# Follow the instructions below:
+#
+# 1. Implement your custom op in C++ by following
+#    `ONNX Runtime instructions <https://github.com/microsoft/onnxruntime/blob/gh-pages/docs/reference/operators/add-custom-op.md>`__.
+# 2. Download the ONNX Runtime source distribution from
+#    `ONNX Runtime releases `__.
+# 3. Compile and link your custom op library to ONNX Runtime, for example:
+#
+# .. code-block:: bash
+#
+#    $ gcc -shared -o libcustom_op_library.so custom_op_library.cc -L /path/to/downloaded/ort/lib/ -lonnxruntime -fPIC
+#
+# 4. Run the model with ONNX Runtime Python API
+#
+# .. code-block:: python
+#
+#    ort_session_options = onnxruntime.SessionOptions()
+#
+#    # NOTE: Link the custom op library to ONNX Runtime and replace the path
+#    # with the path to your custom op library
+#    ort_session_options.register_custom_ops_library(
+#        "/path/to/libcustom_op_library.so"
+#    )
+#    ort_session = onnxruntime.InferenceSession(
+#        "./custom_addandround_model.onnx", providers=['CPUExecutionProvider'], sess_options=ort_session_options)
+#
+#    def to_numpy(tensor):
+#        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+#
+#    onnx_input = export_output.adapt_torch_inputs_to_onnx(input_addandround_x)
+#    onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+#    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+#
+#    torch_outputs = custom_addandround_model(input_addandround_x)
+#    torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
+#
+#    assert len(torch_outputs) == len(onnxruntime_outputs)
+#    for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
+#        torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
+#
+######################################################################
+# Conclusion
+# ----------
+#
+# Congratulations! That is it! In this tutorial, we have learned how to effectively
+# manage unsupported operators within ONNX models. We explored the ONNX registry and
+# discovered how to create custom implementations for unsupported ATen operators using
+# ONNX Script, register them in the ONNX registry, and integrate them with ONNX Runtime.
+# Additionally, we've gained insights into addressing unsupported PyTorch operators by
+# implementing them in PyTorch FX, registering them in the ONNX registry, and linking
+# them to ONNX Runtime. The tutorial concluded by demonstrating how to export models
+# containing custom operators to ONNX files and validate their functionality using
+# ONNX Runtime, providing us with a comprehensive understanding of handling unsupported
+# operators in the ONNX ecosystem.
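+#
+# As a compact recap (a sketch reusing names from this tutorial), the
+# registration pattern is the same three steps in every scenario:
+#
+# .. code-block:: python
+#
+#    onnx_registry = torch.onnx.OnnxRegistry()
+#    onnx_registry.register_op(
+#        namespace="aten", op_name="gelu", overload="default",
+#        function=custom_aten_gelu)
+#    export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
+#    # then pass export_options to torch.onnx.dynamo_export(...)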
+#
+######################################################################
+# Further reading
+# ---------------
+#
+# The list below refers to tutorials that range from basic examples to advanced scenarios,
+# not necessarily in the order they are listed.
+# Feel free to jump directly to specific topics of your interest or
+# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
+#
+# .. include:: /beginner_source/onnx/onnx_toc.txt
+#
+# .. toctree::
+#    :hidden:
+#
diff --git a/beginner_source/onnx/onnx_toc.txt b/beginner_source/onnx/onnx_toc.txt
new file mode 100644
index 00000000000..bd57025641b
--- /dev/null
+++ b/beginner_source/onnx/onnx_toc.txt
@@ -0,0 +1,2 @@
+| 1. `Export a PyTorch model to ONNX `_
+| 2. `Introduction to ONNX registry `_
\ No newline at end of file
diff --git a/en-wordlist.txt b/en-wordlist.txt
index ee2c79b6b41..bb504f7b292 100644
--- a/en-wordlist.txt
+++ b/en-wordlist.txt
@@ -132,6 +132,7 @@
 Lipschitz
 logits
 Lua
 Luong
+macOS
 MLP
 MLPs
 MNIST
@@ -147,11 +148,15 @@
 NTK
 NUMA
 NaN
 NanoGPT
+Netron
 NeurIPS
 NumPy
 Numericalization
 Numpy's
 ONNX
+ONNX's
+ONNX Runtime
+ONNX Script
 OpenAI
 OpenMP
 Ornstein
@@ -386,6 +391,7 @@
 prewritten
 primals
 profiler
 profilers
+protobuf
 py
 pytorch
 quantized
diff --git a/index.rst b/index.rst
index 3070002466f..0d165be5a8c 100644
--- a/index.rst
+++ b/index.rst
@@ -272,6 +272,22 @@ What's new in PyTorch tutorials?
      :tags: Text

+.. ONNX
+
+.. customcarditem::
+   :header: (optional) Exporting a PyTorch model to ONNX using the TorchDynamo backend and Running it using ONNX Runtime
+   :card_description: Build an image classifier model in PyTorch and convert it to ONNX before deploying it with ONNX Runtime.
+   :image: _static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
+   :link: beginner/onnx/export_simple_model_to_onnx_tutorial.html
+   :tags: Production,ONNX,Backends
+
+.. customcarditem::
+   :header: Introduction to ONNX Registry
+   :card_description: Demonstrate end-to-end how to address unsupported operators by using ONNX Registry.
+   :image: _static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
+   :link: beginner/onnx/onnx_registry_tutorial.html
+   :tags: Production,ONNX,Backends
+
 .. Reinforcement Learning

 .. customcarditem::
@@ -329,11 +345,12 @@
      :tags: Production,TorchScript

 .. customcarditem::
-   :header: (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
+   :header: (optional) Exporting a PyTorch Model to ONNX using the TorchScript backend and Running it using ONNX Runtime
    :card_description: Convert a model defined in PyTorch into the ONNX format and then run it with ONNX Runtime.
    :image: _static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png
    :link: advanced/super_resolution_with_onnxruntime.html
-   :tags: Production
+   :tags: Production,ONNX
+

 .. Code Transformations with FX
@@ -902,6 +919,14 @@ Additional Resources

    beginner/torchtext_custom_dataset_tutorial

+.. toctree::
+   :maxdepth: 2
+   :includehidden:
+   :hidden:
+   :caption: Backends
+
+   beginner/onnx/intro_onnx
+
 .. toctree::
    :maxdepth: 2
    :includehidden:
@@ -918,6 +943,7 @@ Additional Resources
    :hidden:
    :caption: Deploying PyTorch Models in Production

+   beginner/onnx/intro_onnx
    intermediate/flask_rest_api_tutorial
    beginner/Intro_to_TorchScript_tutorial
    advanced/cpp_export
diff --git a/intermediate_source/memory_format_tutorial.py b/intermediate_source/memory_format_tutorial.py
index f08980265de..26bc5c9d53c 100644
--- a/intermediate_source/memory_format_tutorial.py
+++ b/intermediate_source/memory_format_tutorial.py
@@ -131,7 +131,7 @@
 # produces output in contiguous memory format. Otherwise, output will
 # be in channels last memory format.

-if torch.backends.cudnn.version() >= 7603:
+if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
     model = torch.nn.Conv2d(8, 4, 3).cuda().half()
     model = model.to(memory_format=torch.channels_last)  # Module parameters need to be channels last
diff --git a/requirements.txt b/requirements.txt
index 84c35e78d08..31b3f0ad16b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,6 +33,9 @@
 datasets
 transformers
 torchmultimodal-nightly # needs to be updated to stable as soon as it's avaialable
 deep_phonemizer==0.0.17
+onnx
+onnxscript
+onnxruntime
 importlib-metadata==6.8.0