diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py
index cd23c0e05d8..596ab1700c9 100644
--- a/.jenkins/validate_tutorials_built.py
+++ b/.jenkins/validate_tutorials_built.py
@@ -10,6 +10,7 @@
NOT_RUN = [
"beginner_source/basics/intro", # no code
+ "beginner_source/onnx/intro_onnx",
"beginner_source/translation_transformer",
"beginner_source/profiler",
"beginner_source/saving_loading_models",
diff --git a/_static/img/onnx/custom_addandround_function.png b/_static/img/onnx/custom_addandround_function.png
new file mode 100644
index 00000000000..a0c7000161e
Binary files /dev/null and b/_static/img/onnx/custom_addandround_function.png differ
diff --git a/_static/img/onnx/custom_addandround_model.png b/_static/img/onnx/custom_addandround_model.png
new file mode 100644
index 00000000000..793d8cfbb5d
Binary files /dev/null and b/_static/img/onnx/custom_addandround_model.png differ
diff --git a/_static/img/onnx/custom_aten_add_function.png b/_static/img/onnx/custom_aten_add_function.png
new file mode 100644
index 00000000000..d9f927ce707
Binary files /dev/null and b/_static/img/onnx/custom_aten_add_function.png differ
diff --git a/_static/img/onnx/custom_aten_add_model.png b/_static/img/onnx/custom_aten_add_model.png
new file mode 100644
index 00000000000..e5ef1c71742
Binary files /dev/null and b/_static/img/onnx/custom_aten_add_model.png differ
diff --git a/_static/img/onnx/custom_aten_gelu_function.png b/_static/img/onnx/custom_aten_gelu_function.png
new file mode 100644
index 00000000000..5cb573e7dcb
Binary files /dev/null and b/_static/img/onnx/custom_aten_gelu_function.png differ
diff --git a/_static/img/onnx/custom_aten_gelu_model.png b/_static/img/onnx/custom_aten_gelu_model.png
new file mode 100644
index 00000000000..6bc46337b48
Binary files /dev/null and b/_static/img/onnx/custom_aten_gelu_model.png differ
diff --git a/_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png b/_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png
new file mode 100755
index 00000000000..0c29c168798
Binary files /dev/null and b/_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png differ
diff --git a/_static/img/onnx/netron_web_ui.png b/_static/img/onnx/netron_web_ui.png
new file mode 100755
index 00000000000..f88936eb824
Binary files /dev/null and b/_static/img/onnx/netron_web_ui.png differ
diff --git a/_static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png b/_static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
new file mode 100755
index 00000000000..00156df042e
Binary files /dev/null and b/_static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png differ
diff --git a/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png b/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png
index 426a14d98f5..00156df042e 100644
Binary files a/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png and b/_static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png differ
diff --git a/_templates/layout.html b/_templates/layout.html
index 660c6870217..242e347d092 100644
--- a/_templates/layout.html
+++ b/_templates/layout.html
@@ -107,7 +107,7 @@
diff --git a/advanced_source/super_resolution_with_onnxruntime.py b/advanced_source/super_resolution_with_onnxruntime.py
index 835a79bd3a0..466c124fe67 100644
--- a/advanced_source/super_resolution_with_onnxruntime.py
+++ b/advanced_source/super_resolution_with_onnxruntime.py
@@ -1,10 +1,17 @@
"""
(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
-========================================================================
+===================================================================================
+
+.. note::
+ As of PyTorch 2.1, there are two versions of ONNX Exporter.
+
+   * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter, based on the TorchDynamo technology released with PyTorch 2.0.
+   * ``torch.onnx.export`` is based on the TorchScript backend and has been available since PyTorch 1.2.0.
+
In this tutorial, we describe how to convert a model defined
-in PyTorch into the ONNX format and then run it with ONNX Runtime.
+in PyTorch into the ONNX format using the TorchScript ``torch.onnx.export`` ONNX exporter.
+The exported model will be executed with ONNX Runtime.
ONNX Runtime is a performance-focused engine for ONNX models,
which inferences efficiently across multiple platforms and hardware
(Windows, Linux, and Mac and on both CPUs and GPUs).
@@ -15,13 +22,17 @@
For this tutorial, you will need to install `ONNX <https://github.com/onnx/onnx>`__
and `ONNX Runtime <https://github.com/microsoft/onnxruntime>`__.
You can get binary builds of ONNX and ONNX Runtime with
-``pip install onnx onnxruntime``.
+
+.. code-block:: bash
+
+    pip install onnx onnxruntime
+
ONNX Runtime recommends using the latest stable runtime for PyTorch.
"""
# Some standard imports
-import io
import numpy as np
from torch import nn
@@ -185,7 +196,7 @@ def _initialize_weights(self):
import onnxruntime
-ort_session = onnxruntime.InferenceSession("super_resolution.onnx")
+ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])
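+# ``providers`` makes the execution backend explicit; a GPU build of onnxruntime
+# could use ["CUDAExecutionProvider"] here instead.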
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
diff --git a/beginner_source/onnx/README.txt b/beginner_source/onnx/README.txt
new file mode 100644
index 00000000000..5c9249c640e
--- /dev/null
+++ b/beginner_source/onnx/README.txt
@@ -0,0 +1,14 @@
+ONNX
+----
+
+1. intro_onnx.py
+ Introduction to ONNX
+   https://pytorch.org/tutorials/beginner/onnx/intro_onnx.html
+
+2. export_simple_model_to_onnx_tutorial.py
+ Export a PyTorch model to ONNX
+ https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html
+
+3. onnx_registry_tutorial.py
+   Introduction to ONNX Registry
+ https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html
\ No newline at end of file
diff --git a/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py b/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
new file mode 100644
index 00000000000..0c76847d411
--- /dev/null
+++ b/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
@@ -0,0 +1,212 @@
+# -*- coding: utf-8 -*-
+"""
+`Introduction to ONNX <intro_onnx.html>`_ ||
+**Export a PyTorch model to ONNX** ||
+`Introduction to ONNX Registry <onnx_registry_tutorial.html>`_
+
+Export a PyTorch model to ONNX
+==============================
+
+**Author**: `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_
+
+.. note::
+ As of PyTorch 2.1, there are two versions of ONNX Exporter.
+
+ * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
+ * ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0
+
+"""
+
+###############################################################################
+# In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
+# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
+# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
+# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
+#
+# While PyTorch is great for iterating on the development of models, the model can be deployed to production
+# using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)!
+#
+# ONNX is a flexible open standard format for representing machine learning models. Its standardized
+# representation allows models to be executed across a gamut of hardware platforms and runtime environments,
+# from large-scale cloud-based supercomputers to resource-constrained edge devices, such as your web browser and phone.
+#
+# In this tutorial, we’ll learn how to:
+#
+# 1. Install the required dependencies.
+# 2. Author a simple image classifier model.
+# 3. Export the model to ONNX format.
+# 4. Save the ONNX model in a file.
+# 5. Visualize the ONNX model graph using `Netron <https://github.com/lutzroeder/netron>`_.
+# 6. Execute the ONNX model with ONNX Runtime.
+# 7. Compare the PyTorch results with the ones from the ONNX Runtime.
+#
+# 1. Install the required dependencies
+# ------------------------------------
+# Because the ONNX exporter uses ``onnx`` and ``onnxscript`` to translate PyTorch operators into ONNX operators,
+# we will need to install them.
+#
+# .. code-block:: bash
+#
+# pip install onnx
+# pip install onnxscript
+#
+# 2. Author a simple image classifier model
+# -----------------------------------------
+#
+# Once your environment is set up, let’s start modeling our image classifier with PyTorch,
+# exactly like we did in the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_.
+#
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MyModel(nn.Module):
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self.conv1 = nn.Conv2d(1, 6, 5)
+ self.conv2 = nn.Conv2d(6, 16, 5)
+ self.fc1 = nn.Linear(16 * 5 * 5, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+
+ def forward(self, x):
+ x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
+ x = F.max_pool2d(F.relu(self.conv2(x)), 2)
+ x = torch.flatten(x, 1)
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+######################################################################
+# 3. Export the model to ONNX format
+# ----------------------------------
+#
+# Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
+# Next, we can export the model to ONNX format.
+
+torch_model = MyModel()
+torch_input = torch.randn(1, 1, 32, 32)
+export_output = torch.onnx.dynamo_export(torch_model, torch_input)
+
+######################################################################
+# As we can see, we didn't need any code change to the model.
+# The resulting ONNX model is stored within ``torch.onnx.ExportOutput`` as a binary protobuf file.
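+#
+# The in-memory model is available through ``export_output.model_proto``, a standard
+# ``onnx.ModelProto`` that can be inspected directly (a minimal sketch, assuming the
+# export above succeeded):
+#
+# .. code-block:: python
+#
+#    model_proto = export_output.model_proto  # an onnx.ModelProto
+#    print(len(model_proto.graph.node))       # number of nodes in the ONNX graph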
+#
+# 4. Save the ONNX model in a file
+# --------------------------------
+#
+# Although having the exported model loaded in memory is useful in many applications,
+# we can save it to disk with the following code:
+
+export_output.save("my_image_classifier.onnx")
+
+######################################################################
+# You can load the ONNX file back into memory and check if it is well formed with the following code:
+
+import onnx
+onnx_model = onnx.load("my_image_classifier.onnx")
+onnx.checker.check_model(onnx_model)
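+
+# Optionally, we can print a human-readable summary of the graph. This uses
+# ``onnx.helper.printable_graph`` (deprecated in newer ``onnx`` releases, but
+# widely available); the output can be long for larger models.
+print(onnx.helper.printable_graph(onnx_model.graph))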
+
+######################################################################
+# 5. Visualize the ONNX model graph using Netron
+# ----------------------------------------------
+#
+# Now that we have our model saved in a file, we can visualize it with `Netron <https://github.com/lutzroeder/netron>`_.
+# Netron can either be installed on macOS, Linux or Windows computers, or run directly from the browser.
+# Let's try the web version by opening the following link: https://netron.app/.
+#
+# .. image:: ../../_static/img/onnx/netron_web_ui.png
+# :width: 70%
+# :align: center
+#
+#
+# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
+# clicking the **Open model** button.
+#
+# .. image:: ../../_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png
+# :width: 50%
+#
+#
+# And that is it! We have successfully exported our PyTorch model to ONNX format and visualized it with Netron.
+#
+# 6. Execute the ONNX model with ONNX Runtime
+# -------------------------------------------
+#
+# The last step is executing the ONNX model with ONNX Runtime, but before we do that, let's install ONNX Runtime.
+#
+# .. code-block:: bash
+#
+# pip install onnxruntime
+#
+# The ONNX standard does not support all the data structures and types that PyTorch does,
+# so we need to adapt the PyTorch input to ONNX format before feeding it to ONNX Runtime.
+# In our example, the input happens to be the same, but in more complex models the ONNX model
+# can have more inputs than the original PyTorch model.
+#
+# ONNX Runtime requires an additional step: converting all PyTorch tensors to NumPy (on CPU)
+# and wrapping them in a dictionary, with the input names as keys and the NumPy tensors as values.
+#
+# Now we can create an *ONNX Runtime Inference Session*, execute the ONNX model with the processed input
+# and get the output. In this tutorial, ONNX Runtime is executed on CPU, but it could be executed on GPU as well.
+
+import onnxruntime
+
+onnx_input = export_output.adapt_torch_inputs_to_onnx(torch_input)
+print(f"Input length: {len(onnx_input)}")
+print(f"Sample input: {onnx_input}")
+
+ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])
+
+def to_numpy(tensor):
+ return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+
+onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+
+######################################################################
+# 7. Compare the PyTorch results with the ones from the ONNX Runtime
+# -----------------------------------------------------------------
+#
+# The best way to determine whether the exported model behaves as expected is through numerical
+# evaluation against PyTorch, which is our source of truth.
+#
+# For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
+# Before comparing the results, we need to convert PyTorch's output to match ONNX's format.
+
+torch_outputs = torch_model(torch_input)
+torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
+
+assert len(torch_outputs) == len(onnxruntime_outputs)
+for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
+ torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
+
+print("PyTorch and ONNX Runtime output matched!")
+print(f"Output length: {len(onnxruntime_outputs)}")
+print(f"Sample output: {onnxruntime_outputs}")
+
+######################################################################
+# Conclusion
+# ----------
+#
+# That is about it! We have successfully exported our PyTorch model to ONNX format,
+# saved the model to disk, viewed it using Netron, executed it with ONNX Runtime
+# and finally compared its numerical results with PyTorch's.
+#
+# Further reading
+# ---------------
+#
+# The list below refers to tutorials that range from basic examples to advanced scenarios,
+# not necessarily in the order they are listed.
+# Feel free to jump directly to specific topics of your interest or
+# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
+#
+# .. include:: /beginner_source/onnx/onnx_toc.txt
+#
+# .. toctree::
+# :hidden:
+#
\ No newline at end of file
diff --git a/beginner_source/onnx/intro_onnx.py b/beginner_source/onnx/intro_onnx.py
new file mode 100644
index 00000000000..d86e917a5e5
--- /dev/null
+++ b/beginner_source/onnx/intro_onnx.py
@@ -0,0 +1,68 @@
+"""
+**Introduction to ONNX** ||
+`Export a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
+`Introduction to ONNX Registry <onnx_registry_tutorial.html>`_
+
+Introduction to ONNX
+====================
+
+Author:
+`Thiago Crepaldi <https://github.com/thiagocrepaldi>`_
+
+`Open Neural Network eXchange (ONNX) <https://onnx.ai/>`_ is an open standard
+format for representing machine learning models. The ``torch.onnx`` module provides APIs to
+capture the computation graph from a native PyTorch :class:`torch.nn.Module` model and convert
+it into an `ONNX graph <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.
+
+The exported model can be consumed by any of the many
+`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_,
+including Microsoft's `ONNX Runtime <https://www.onnxruntime.ai>`_.
+
+.. note::
+ Currently, there are two flavors of ONNX exporter APIs,
+    but this tutorial will focus on ``torch.onnx.dynamo_export``.
+
+The TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its
+bytecode into an `FX graph <https://pytorch.org/docs/stable/fx.html>`_.
+The resulting FX graph is polished before it is finally translated into an
+`ONNX graph <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.
+
+The main advantage of this approach is that the `FX graph <https://pytorch.org/docs/stable/fx.html>`_ is captured using
+bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques.
+
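+At a high level, the exporter boils down to a single API call (a minimal sketch;
+the export tutorial listed below walks through a complete example):
+
+.. code-block:: python
+
+    import torch
+
+    class MyModel(torch.nn.Module):
+        def forward(self, x):
+            return x.relu()
+
+    export_output = torch.onnx.dynamo_export(MyModel(), torch.randn(2, 3))
+    export_output.save("my_model.onnx")
+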
+Dependencies
+------------
+
+PyTorch 2.1.0 or newer is required.
+
+The ONNX exporter depends on extra Python packages:
+
+  - `ONNX <https://onnx.ai>`_
+  - `ONNX Script <https://github.com/microsoft/onnxscript>`_
+
+They can be installed through `pip <https://pypi.org/project/pip/>`_:
+
+.. code-block:: bash
+
+    pip install --upgrade onnx onnxscript
+
+.. note::
+    This tutorial leverages `onnxscript <https://github.com/microsoft/onnxscript>`__
+    to create custom ONNX operators. onnxscript is a Python library that allows users to
+    create custom ONNX operators in Python. It is prerequisite learning material for
+    this tutorial, so please make sure you have read the onnxscript tutorial before proceeding.
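+
+A quick import check confirms the environment is ready (exact versions will vary):
+
+.. code-block:: python
+
+    import onnx
+    import onnxscript
+
+    print(onnx.__version__)
+    print(onnxscript.__version__)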
+
+Further reading
+---------------
+
+The list below refers to tutorials that range from basic examples to advanced scenarios,
+not necessarily in the order they are listed.
+Feel free to jump directly to specific topics of your interest or
+sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
+
+.. include:: /beginner_source/onnx/onnx_toc.txt
+
+.. toctree::
+ :hidden:
+
+"""
diff --git a/beginner_source/onnx/onnx_registry_tutorial.py b/beginner_source/onnx/onnx_registry_tutorial.py
new file mode 100644
index 00000000000..f3c6cb94b7a
--- /dev/null
+++ b/beginner_source/onnx/onnx_registry_tutorial.py
@@ -0,0 +1,469 @@
+# -*- coding: utf-8 -*-
+
+"""
+`Introduction to ONNX <intro_onnx.html>`_ ||
+`Export a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
+**Introduction to ONNX Registry**
+
+Introduction to ONNX Registry
+=============================
+
+**Author:** Ti-Tai Wang (titaiwang@microsoft.com)
+"""
+
+
+###############################################################################
+# Overview
+# ~~~~~~~~
+#
+# This tutorial is an introduction to the ONNX registry, which
+# empowers users to create their own ONNX registries, enabling
+# them to address unsupported operators in ONNX.
+#
+# In this tutorial, we will cover the following scenarios:
+#
+# * Unsupported ATen operators
+# * Unsupported ATen operators with existing ONNX Runtime support
+# * Unsupported PyTorch operators with no ONNX Runtime support
+#
+
+
+import torch
+print(torch.__version__)
+torch.manual_seed(191009) # set the seed for reproducibility
+
+import onnxscript # pip install onnxscript
+print(onnxscript.__version__)
+
+# NOTE: opset18 is the only version of ONNX operators we are
+# using in torch.onnx.dynamo_export for now.
+from onnxscript import opset18
+
+import onnxruntime # pip install onnxruntime
+print(onnxruntime.__version__)
+
+
+######################################################################
+# Unsupported ATen operators
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# ATen operators are implemented by PyTorch, and the ONNX exporter team must manually implement the
+# conversion from ATen operators to ONNX operators through `ONNX Script <https://onnxscript.ai/>`__. Although the ONNX exporter
+# team has been making their best efforts to support as many ATen operators as possible, some ATen
+# operators are still not supported. In this section, we will demonstrate how you can implement
+# unsupported ATen operators yourself and contribute the implementation back to the project through the PyTorch GitHub repo.
+#
+# If the model cannot be exported because an ATen operator is not supported by ONNX
+# (for example, ``aten::add.Tensor``), the export fails with an error message such as:
+# ``RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.add.Tensor']}``.
+#
+# To support unsupported ATen operators, we need two things:
+#
+# * The unsupported ATen operator namespace, operator name, and the
+#   corresponding overload (for example ``namespace::op_name.overload`` - ``aten::add.Tensor``),
+#   which can be found in the error message.
+# * The implementation of the operator in `ONNX Script <https://github.com/microsoft/onnxscript>`__.
+#
+
+
+# NOTE: `is_registered_op` is a method in ONNX registry that checks
+# whether the operator is supported by ONNX. If the operator is not
+# supported, it will return False. Otherwise, it will return True.
+onnx_registry = torch.onnx.OnnxRegistry()
+# aten::add.default and aten::add.Tensor are supported by ONNX
+print(f"aten::add.default is supported by ONNX registry: \
+ {onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='default')}")
+# aten::add.Tensor is the one invoked by torch.ops.aten.add
+print(f"aten::add.Tensor is supported by ONNX registry: \
+ {onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='Tensor')}")
+
+
+######################################################################
+# In this example, we will assume that ``aten::add.Tensor`` is not supported by the ONNX registry,
+# and we will demonstrate how to support it. The ONNX registry allows user overrides for operator
+# registration. In this case, we will override the registration of ``aten::add.Tensor`` with our
+# implementation and verify it. For an operator that is genuinely unsupported,
+# :meth:`onnx_registry.is_registered_op` would return ``False``.
+#
+
+
+class Model(torch.nn.Module):
+ def forward(self, input_x, input_y):
+ # specifically call out aten::add
+ return torch.ops.aten.add(input_x, input_y)
+
+input_add_x = torch.randn(3, 4)
+input_add_y = torch.randn(3, 4)
+aten_add_model = Model()
+
+
+# Let's create an onnxscript function to support aten::add.Tensor.
+# It can be named anything, and the name shows up later in the Netron graph.
+custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1)
+
+# NOTE: The function signature must match the signature of the unsupported ATen operator.
+# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+# NOTE: All attributes must be annotated with type hints.
+@onnxscript.script(custom_aten)
+def custom_aten_add(input_x, input_y, alpha: float = 1.0):
+ alpha = opset18.CastLike(alpha, input_y)
+ input_y = opset18.Mul(input_y, alpha)
+ return opset18.Add(input_x, input_y)
+
+
+# Now we have both things we need to support unsupported ATen operators.
+# Let's register the ``custom_aten_add`` function with the ONNX registry, and
+# export the model to ONNX again.
+onnx_registry.register_op(
+ namespace="aten", op_name="add", overload="Tensor", function=custom_aten_add
+ )
+print(f"aten::add.Tensor is supported by ONNX registry: \
+ {onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='Tensor')}"
+ )
+export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
+export_output = torch.onnx.dynamo_export(
+ aten_add_model, input_add_x, input_add_y, export_options=export_options
+ )
+
+######################################################################
+# Make sure the model uses ``custom_aten_add`` instead of ``aten::add.Tensor``.
+# The graph has one graph node for ``custom_aten_add``, and inside
+# ``custom_aten_add`` there are four function nodes, one for each
+# operator, and one for the constant attribute.
+#
+
+# graph node domain is the custom domain we registered
+assert export_output.model_proto.graph.node[0].domain == "custom.aten"
+assert len(export_output.model_proto.graph.node) == 1
+# graph node name is the function name
+assert export_output.model_proto.graph.node[0].op_type == "custom_aten_add"
+# function node domain is empty because we use standard ONNX operators
+assert export_output.model_proto.functions[0].node[3].domain == ""
+# function node name is the standard ONNX operator name
+assert export_output.model_proto.functions[0].node[3].op_type == "Add"
+
+
+######################################################################
+# ``custom_aten_add_model`` ONNX graph in Netron:
+#
+# .. image:: /_static/img/onnx/custom_aten_add_model.png
+# :width: 70%
+# :align: center
+#
+# Inside the ``custom_aten_add`` function, we can see the three ONNX nodes we
+# used in the function (``CastLike``, ``Add``, and ``Mul``), and one constant attribute:
+#
+# .. image:: /_static/img/onnx/custom_aten_add_function.png
+# :width: 70%
+# :align: center
+#
+# After checking the ONNX graph, we can use ONNX Runtime to run the model.
+#
+
+
+# Use ONNX Runtime to run the model, and compare the results with PyTorch
+export_output.save("./custom_add_model.onnx")
+ort_session = onnxruntime.InferenceSession(
+ "./custom_add_model.onnx", providers=['CPUExecutionProvider']
+ )
+
+def to_numpy(tensor):
+ return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+onnx_input = export_output.adapt_torch_inputs_to_onnx(input_add_x, input_add_y)
+onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+
+torch_outputs = aten_add_model(input_add_x, input_add_y)
+torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
+
+assert len(torch_outputs) == len(onnxruntime_outputs)
+for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
+ torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
+
+
+######################################################################
+# Unsupported ATen operators with existing ONNX Runtime support
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# In this case, the unsupported ATen operator is supported by ONNX Runtime but not
+# by the ONNX spec. This occurs because ONNX Runtime allows users to implement their own
+# custom operators, which ONNX Runtime then supports natively. When the need arises, ONNX Runtime
+# will contribute these custom operators to the ONNX spec. Therefore, in the ONNX registry,
+# we only need to register the operator with the recognized namespace and operator name.
+#
+# In the following example, we would like to use the ``Gelu`` operator provided by ONNX Runtime,
+# which is not the same as the ``Gelu`` in the ONNX spec. Thus, we register the ``Gelu`` operator
+# with the namespace ``com.microsoft`` and operator name ``Gelu``.
+#
+
+
+class CustomGelu(torch.nn.Module):
+ def forward(self, input_x):
+ return torch.ops.aten.gelu(input_x)
+
+# com.microsoft is an official ONNX Runtime namespace
+custom_ort = onnxscript.values.Opset(domain="com.microsoft", version=1)
+
+# NOTE: The function signature must match the signature of the unsupported ATen operator.
+# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+# NOTE: All attributes must be annotated with type hints.
+@onnxscript.script(custom_ort)
+def custom_aten_gelu(input_x, approximate: str = "none"):
+    # We know com.microsoft::Gelu is supported by ONNX Runtime;
+    # it is only missing from the ONNX spec
+ return custom_ort.Gelu(input_x)
+
+
+onnx_registry = torch.onnx.OnnxRegistry()
+onnx_registry.register_op(
+ namespace="aten", op_name="gelu", overload="default", function=custom_aten_gelu)
+export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
+
+aten_gelu_model = CustomGelu()
+input_gelu_x = torch.randn(3, 3)
+
+export_output = torch.onnx.dynamo_export(
+ aten_gelu_model, input_gelu_x, export_options=export_options
+ )
+
+
+######################################################################
+# Make sure the model uses ``custom_aten_gelu`` instead of
+# ``aten::gelu``. The graph has one graph node for
+# ``custom_aten_gelu``, and inside ``custom_aten_gelu``, there is a function
+# node for ``Gelu`` with namespace ``com.microsoft``.
+#
+
+# graph node domain is the custom domain we registered
+assert export_output.model_proto.graph.node[0].domain == "com.microsoft"
+# graph node name is the function name
+assert export_output.model_proto.graph.node[0].op_type == "custom_aten_gelu"
+# function node domain is the custom domain we registered
+assert export_output.model_proto.functions[0].node[0].domain == "com.microsoft"
+# function node name is the node name used in the function
+assert export_output.model_proto.functions[0].node[0].op_type == "Gelu"
+
+
+######################################################################
+# The following diagram shows the ``custom_aten_gelu_model`` ONNX graph in Netron:
+#
+# .. image:: /_static/img/onnx/custom_aten_gelu_model.png
+# :width: 70%
+# :align: center
+#
+# Inside the ``custom_aten_gelu`` function, we can see the ``Gelu`` node from the
+# ``com.microsoft`` domain used in the function:
+#
+# .. image:: /_static/img/onnx/custom_aten_gelu_function.png
+#
+# After checking the ONNX graph, we can use ONNX Runtime to run the model.
+
+
+export_output.save("./custom_gelu_model.onnx")
+ort_session = onnxruntime.InferenceSession(
+ "./custom_gelu_model.onnx", providers=['CPUExecutionProvider']
+ )
+
+def to_numpy(tensor):
+ return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+onnx_input = export_output.adapt_torch_inputs_to_onnx(input_gelu_x)
+onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+
+torch_outputs = aten_gelu_model(input_gelu_x)
+torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
+
+assert len(torch_outputs) == len(onnxruntime_outputs)
+for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
+ torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
+
+
+######################################################################
+# Unsupported PyTorch operators with no ONNX Runtime support
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# In this case, the operator is not supported by any framework, and we
+# would like to use it in the ONNX graph. Therefore, we need to implement
+# the operator in three places:
+#
+# 1. PyTorch FX graph
+# 2. ONNX Registry
+# 3. ONNX Runtime
+#
+# In the following example, we would like to use a custom operator
+# that takes one tensor input, and returns one output. The operator adds
+# the input to itself, and returns the rounded result.
+#
+# **Custom Ops Registration in PyTorch FX Graph (Beta)**
+#
+# Firstly, we need to implement the operator in PyTorch FX graph.
+# This can be done by using ``torch._custom_op``.
+#
+
+# NOTE: This is a beta feature in PyTorch, and is subject to change.
+from torch._custom_op import impl as custom_op
+
+@custom_op.custom_op("mylibrary::addandround_op")
+def addandround_op(tensor_x: torch.Tensor) -> torch.Tensor:
+ ...
+
+@addandround_op.impl_abstract()
+def addandround_op_impl_abstract(tensor_x):
+ return torch.empty_like(tensor_x)
+
+@addandround_op.impl("cpu")
+def addandround_op_impl(tensor_x):
+ # add x to itself, and round the result
+ return torch.round(tensor_x + tensor_x)
+
+torch._dynamo.allow_in_graph(addandround_op)
+
+class CustomFoo(torch.nn.Module):
+ def forward(self, tensor_x):
+ return addandround_op(tensor_x)
+
+input_addandround_x = torch.randn(3)
+custom_addandround_model = CustomFoo()
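+
+# Eager-mode sanity check: per the CPU implementation registered above, the op
+# doubles the input and rounds the result.
+print(custom_addandround_model(input_addandround_x))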
+
+
+######################################################################
+# **Custom Ops Registration in ONNX Registry**
+#
+# For steps 2 and 3, we need to implement the operator in the ONNX registry.
+# In this example, we will implement the operator in the ONNX registry
+# with the namespace ``test.customop`` and operator names ``CustomOpOne``
+# and ``CustomOpTwo``. These two ops are registered and built in
+# `cpu_ops.cc `__.
+#
+
+
+custom_opset = onnxscript.values.Opset(domain="test.customop", version=1)
+
+# NOTE: The function signature must match the signature of the unsupported ATen operator.
+# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
+# NOTE: All attributes must be annotated with type hints.
+@onnxscript.script(custom_opset)
+def custom_addandround(input_x):
+ # The same as opset18.Add(x, x)
+ add_x = custom_opset.CustomOpOne(input_x, input_x)
+    # The same as opset18.Round(x)
+ round_x = custom_opset.CustomOpTwo(add_x)
+ # Cast to FLOAT to match the ONNX type
+ return opset18.Cast(round_x, to=1)
+
+
+onnx_registry = torch.onnx.OnnxRegistry()
+onnx_registry.register_op(
+ namespace="mylibrary", op_name="addandround_op", overload="default", function=custom_addandround
+ )
+
+export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
+export_output = torch.onnx.dynamo_export(
+ custom_addandround_model, input_addandround_x, export_options=export_options
+ )
+export_output.save("./custom_addandround_model.onnx")
+
+
+######################################################################
+# The exported model proto is accessible through ``export_output.model_proto``.
+# The graph has one graph node for ``custom_addandround``, and inside ``custom_addandround``,
+# there are two function nodes, one for each operator.
+#
+
+assert export_output.model_proto.graph.node[0].domain == "test.customop"
+assert export_output.model_proto.graph.node[0].op_type == "custom_addandround"
+assert export_output.model_proto.functions[0].node[0].domain == "test.customop"
+assert export_output.model_proto.functions[0].node[0].op_type == "CustomOpOne"
+assert export_output.model_proto.functions[0].node[1].domain == "test.customop"
+assert export_output.model_proto.functions[0].node[1].op_type == "CustomOpTwo"
+
+
+######################################################################
+# ``custom_addandround_model`` ONNX graph in Netron:
+#
+# .. image:: /_static/img/onnx/custom_addandround_model.png
+# :width: 70%
+# :align: center
+#
+# Inside the ``custom_addandround`` function, we can see the two custom operator nodes
+# used in the function (``CustomOpOne`` and ``CustomOpTwo``), and they are from the
+# ``test.customop`` domain:
+#
+# .. image:: /_static/img/onnx/custom_addandround_function.png
+#
+
+######################################################################
+# **Custom Ops Registration in ONNX Runtime**
+#
+# To link your custom op library to ONNX Runtime, you need to
+# compile your C++ code into a shared library and load it into ONNX Runtime.
+# Follow the instructions below:
+#
+# 1. Implement your custom op in C++ by following
+#    `ONNX Runtime instructions <https://github.com/microsoft/onnxruntime/blob/gh-pages/docs/reference/operators/add-custom-op.md>`__.
+# 2. Download ONNX Runtime source distribution from
+#    `ONNX Runtime releases <https://github.com/microsoft/onnxruntime/releases>`__.
+# 3. Compile and link your custom op library to ONNX Runtime, for example:
+#
+# .. code-block:: bash
+#
+# $ gcc -shared -o libcustom_op_library.so custom_op_library.cc -L /path/to/downloaded/ort/lib/ -lonnxruntime -fPIC
+#
+# 4. Run the model with the ONNX Runtime Python API:
+#
+# .. code-block:: python
+#
+# ort_session_options = onnxruntime.SessionOptions()
+#
+# # NOTE: Link the custom op library to ONNX Runtime and replace the path
+# # with the path to your custom op library
+# ort_session_options.register_custom_ops_library(
+# "/path/to/libcustom_op_library.so"
+# )
+# ort_session = onnxruntime.InferenceSession(
+# "./custom_addandround_model.onnx", providers=['CPUExecutionProvider'], sess_options=ort_session_options)
+#
+# def to_numpy(tensor):
+# return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+#
+# onnx_input = export_output.adapt_torch_inputs_to_onnx(input_addandround_x)
+# onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+# onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+#
+# torch_outputs = custom_addandround_model(input_addandround_x)
+# torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
+#
+# assert len(torch_outputs) == len(onnxruntime_outputs)
+# for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
+# torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
+#
+######################################################################
+# Conclusion
+# ----------
+#
+# Congratulations! In this tutorial, we have learned how to effectively
+# manage unsupported operators within ONNX models. We explored the ONNX registry and
+# discovered how to create custom implementations for unsupported ATen operators using
+# ONNX Script, registering them in the ONNX registry, and integrating them with ONNX Runtime.
+# Additionally, we've gained insights into addressing unsupported PyTorch operators by
+# implementing them in PyTorch FX, registering them in the ONNX registry, and linking
+# them to ONNX Runtime. The tutorial concluded by demonstrating how to export models
+# containing custom operators to ONNX files and validate their functionality using
+# ONNX Runtime, providing us with a comprehensive understanding of handling unsupported
+# operators in the ONNX ecosystem.
+#
+######################################################################
+# Further reading
+# ---------------
+#
+# The list below refers to tutorials that range from basic examples to advanced scenarios,
+# not necessarily in the order they are listed.
+# Feel free to jump directly to specific topics of your interest or
+# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
+#
+# .. include:: /beginner_source/onnx/onnx_toc.txt
+#
+# .. toctree::
+# :hidden:
+#
diff --git a/beginner_source/onnx/onnx_toc.txt b/beginner_source/onnx/onnx_toc.txt
new file mode 100644
index 00000000000..bd57025641b
--- /dev/null
+++ b/beginner_source/onnx/onnx_toc.txt
@@ -0,0 +1,2 @@
+| 1. `Export a PyTorch model to ONNX <beginner/onnx/export_simple_model_to_onnx_tutorial.html>`_
+| 2. `Introduction to ONNX Registry <beginner/onnx/onnx_registry_tutorial.html>`_
\ No newline at end of file
diff --git a/en-wordlist.txt b/en-wordlist.txt
index ee2c79b6b41..bb504f7b292 100644
--- a/en-wordlist.txt
+++ b/en-wordlist.txt
@@ -132,6 +132,7 @@ Lipschitz
logits
Lua
Luong
+macOS
MLP
MLPs
MNIST
@@ -147,11 +148,15 @@ NTK
NUMA
NaN
NanoGPT
+Netron
NeurIPS
NumPy
Numericalization
Numpy's
ONNX
+ONNX's
+ONNX Runtime
+ONNX Script
OpenAI
OpenMP
Ornstein
@@ -386,6 +391,7 @@ prewritten
primals
profiler
profilers
+protobuf
py
pytorch
quantized
diff --git a/index.rst b/index.rst
index 3070002466f..0d165be5a8c 100644
--- a/index.rst
+++ b/index.rst
@@ -272,6 +272,22 @@ What's new in PyTorch tutorials?
:tags: Text
+.. ONNX
+
+.. customcarditem::
+ :header: (optional) Exporting a PyTorch model to ONNX using TorchDynamo backend and Running it using ONNX Runtime
+   :card_description: Build an image classifier model in PyTorch and convert it to ONNX before deploying it with ONNX Runtime.
+ :image: _static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
+ :link: beginner/onnx/export_simple_model_to_onnx_tutorial.html
+ :tags: Production,ONNX,Backends
+
+.. customcarditem::
+ :header: Introduction to ONNX Registry
+ :card_description: Demonstrate end-to-end how to address unsupported operators by using ONNX Registry.
+ :image: _static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
+   :link: beginner/onnx/onnx_registry_tutorial.html
+ :tags: Production,ONNX,Backends
+
.. Reinforcement Learning
.. customcarditem::
@@ -329,11 +345,12 @@ What's new in PyTorch tutorials?
:tags: Production,TorchScript
.. customcarditem::
- :header: (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
+ :header: (optional) Exporting a PyTorch Model to ONNX using TorchScript backend and Running it using ONNX Runtime
:card_description: Convert a model defined in PyTorch into the ONNX format and then run it with ONNX Runtime.
:image: _static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png
:link: advanced/super_resolution_with_onnxruntime.html
- :tags: Production
+ :tags: Production,ONNX
+
.. Code Transformations with FX
@@ -902,6 +919,14 @@ Additional Resources
beginner/torchtext_custom_dataset_tutorial
+.. toctree::
+ :maxdepth: 2
+ :includehidden:
+ :hidden:
+ :caption: Backends
+
+ beginner/onnx/intro_onnx
+
.. toctree::
:maxdepth: 2
:includehidden:
@@ -918,6 +943,7 @@ Additional Resources
:hidden:
:caption: Deploying PyTorch Models in Production
+ beginner/onnx/intro_onnx
intermediate/flask_rest_api_tutorial
beginner/Intro_to_TorchScript_tutorial
advanced/cpp_export
diff --git a/intermediate_source/memory_format_tutorial.py b/intermediate_source/memory_format_tutorial.py
index f08980265de..26bc5c9d53c 100644
--- a/intermediate_source/memory_format_tutorial.py
+++ b/intermediate_source/memory_format_tutorial.py
@@ -131,7 +131,7 @@
# produces output in contiguous memory format. Otherwise, output will
# be in channels last memory format.
-if torch.backends.cudnn.version() >= 7603:
+if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
model = torch.nn.Conv2d(8, 4, 3).cuda().half()
model = model.to(memory_format=torch.channels_last) # Module parameters need to be channels last
diff --git a/requirements.txt b/requirements.txt
index 84c35e78d08..31b3f0ad16b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,6 +33,9 @@ datasets
transformers
torchmultimodal-nightly # needs to be updated to stable as soon as it's avaialable
deep_phonemizer==0.0.17
+onnx
+onnxscript
+onnxruntime
importlib-metadata==6.8.0