Skip to content

Commit 5c17838

Browse files
author
Thiago Crepaldi
committed
Add ONNX tutorial using torch.onnx.dynamo_export API
1 parent ea73167 commit 5c17838

13 files changed

+265
-15
lines changed
Loading

_static/img/onnx/netron_web_ui.png

64.5 KB
Loading
Loading

advanced_source/super_resolution_with_onnxruntime.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
"""
22
(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
3-
========================================================================
3+
===================================================================================
4+
5+
.. Note::
6+
As of PyTorch 2.1, there are two versions of ONNX Exporter.
7+
8+
* ``torch.onnx.dynamo_export`is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
9+
* ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0
410
511
In this tutorial, we describe how to convert a model defined
6-
in PyTorch into the ONNX format and then run it with ONNX Runtime.
12+
in PyTorch into the ONNX format using the TorchScript ``torch.onnx.export` ONNX exporter.
713
14+
The exported model will be executed with ONNX Runtime.
815
ONNX Runtime is a performance-focused engine for ONNX models,
916
which inferences efficiently across multiple platforms and hardware
1017
(Windows, Linux, and Mac and on both CPUs and GPUs).
@@ -15,13 +22,17 @@
1522
For this tutorial, you will need to install `ONNX <https://github.com/onnx/onnx>`__
1623
and `ONNX Runtime <https://github.com/microsoft/onnxruntime>`__.
1724
You can get binary builds of ONNX and ONNX Runtime with
18-
``pip install onnx onnxruntime``.
25+
26+
.. code-block:: bash
27+
28+
%%bash
29+
pip install onnxruntime
30+
1931
ONNX Runtime recommends using the latest stable runtime for PyTorch.
2032
2133
"""
2234

2335
# Some standard imports
24-
import io
2536
import numpy as np
2637

2738
from torch import nn
@@ -185,7 +196,7 @@ def _initialize_weights(self):
185196

186197
import onnxruntime
187198

188-
ort_session = onnxruntime.InferenceSession("super_resolution.onnx")
199+
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])
189200

190201
def to_numpy(tensor):
191202
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

beginner_source/onnx/README.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ ONNX
55
Introduction to ONNX
66
https://pytorch.org/tutorials/onnx/intro_onnx.html
77

8+
2. export_simple_model_to_onnx_tutorial.py
9+
Export a PyTorch model to ONNX
10+
https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
`Introduction to ONNX <intro_onnx.html>`_ ||
4+
**Export a PyTorch model to ONNX**
5+
6+
Export a PyTorch model to ONNX
7+
==============================
8+
9+
**Author**: `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_
10+
11+
.. note::
12+
As of PyTorch 2.1, there are two versions of ONNX Exporter.
13+
14+
* ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
15+
* ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0
16+
17+
"""
18+
19+
###############################################################################
20+
# In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
21+
# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
22+
# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
23+
# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
24+
#
25+
# While PyTorch is great for iterating on the development of models, the model can be deployed to production
26+
# using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)!
27+
#
28+
# ONNX is a flexible open standard format for representing machine learning models which standardized representations
29+
# of machine learning allow them to be executed across a gamut of hardware platforms and runtime environments
30+
# from large-scale cloud-based supercomputers to resource-constrained edge devices, such as your web browser and phone.
31+
#
32+
# In this tutorial, we’ll learn how to:
33+
#
34+
# 1. Install the required dependencies.
35+
# 2. Author a simple image classifier model.
36+
# 3. Export the model to ONNX format.
37+
# 4. Save the ONNX model in a file.
38+
# 5. Visualize the ONNX model graph using `Netron <https://github.com/lutzroeder/netron>`_.
39+
# 6. Execute the ONNX model with `ONNX Runtime`
40+
# 7. Compare the PyTorch results with the ones from the ONNX Runtime.
41+
#
42+
# 1. Install the required dependencies
43+
# ------------------------------------
44+
# Because the ONNX exporter uses ``onnx`` and ``onnxscript`` to translate PyTorch operators into ONNX operators,
45+
# we will need to install them.
46+
#
47+
# .. code-block:: bash
48+
#
49+
# pip install onnx
50+
# pip install onnxscript
51+
#
52+
# 2. Author a simple image classifier model
53+
# -----------------------------------------
54+
#
55+
# Once your environment is set up, let’s start modeling our image classifier with PyTorch,
56+
# exactly like we did in the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_.
57+
#
58+
59+
import torch
60+
import torch.nn as nn
61+
import torch.nn.functional as F
62+
63+
64+
class MyModel(nn.Module):
65+
66+
def __init__(self):
67+
super(MyModel, self).__init__()
68+
self.conv1 = nn.Conv2d(1, 6, 5)
69+
self.conv2 = nn.Conv2d(6, 16, 5)
70+
self.fc1 = nn.Linear(16 * 5 * 5, 120)
71+
self.fc2 = nn.Linear(120, 84)
72+
self.fc3 = nn.Linear(84, 10)
73+
74+
def forward(self, x):
75+
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
76+
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
77+
x = torch.flatten(x, 1)
78+
x = F.relu(self.fc1(x))
79+
x = F.relu(self.fc2(x))
80+
x = self.fc3(x)
81+
return x
82+
83+
######################################################################
84+
# 3. Export the model to ONNX format
85+
# ----------------------------------
86+
#
87+
# Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
88+
# Next, we can export the model to ONNX format.
89+
90+
torch_model = MyModel()
91+
torch_input = torch.randn(1, 1, 32, 32)
92+
export_output = torch.onnx.dynamo_export(torch_model, torch_input)
93+
94+
######################################################################
95+
# As we can see, we didn't need any code change to the model.
96+
# The resulting ONNX model is stored within ``torch.onnx.ExportOutput`` as a binary protobuf file.
97+
#
98+
# 4. Save the ONNX model in a file
99+
# --------------------------------
100+
#
101+
# Although having the exported model loaded in memory is useful in many applications,
102+
# we can save it to disk with the following code:
103+
104+
export_output.save("my_image_classifier.onnx")
105+
106+
######################################################################
107+
# The ONNX file can be loaded back into memory and checked if it is well formed with the following code:
108+
109+
import onnx
110+
onnx_model = onnx.load("my_image_classifier.onnx")
111+
onnx.checker.check_model(onnx_model)
112+
113+
######################################################################
114+
# 5. Visualize the ONNX model graph using Netron
115+
# ----------------------------------------------
116+
#
117+
# Now that we have our model saved in a file, we can visualize it with `Netron <https://github.com/lutzroeder/netron>`_.
118+
# Netron can either be installed on macos, Linux or Windows computers, or run directly from the browser.
119+
# Let's try the web version by opening the following link: https://netron.app/.
120+
#
121+
# .. image:: ../../_static/img/onnx/netron_web_ui.png
122+
# :width: 70%
123+
# :align: center
124+
#
125+
#
126+
# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
127+
# clicking the **Open model** button.
128+
#
129+
# .. image:: ../../_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png
130+
# :width: 50%
131+
#
132+
#
133+
# And that is it! We have successfully exported our PyTorch model to ONNX format and visualized it with Netron.
134+
#
135+
# 6. Execute the ONNX model with ONNX Runtime
136+
# -------------------------------------------
137+
#
138+
# The last step is executing the ONNX model with `ONNX Runtime`, but before we do that, let's install ONNX Runtime.
139+
#
140+
# .. code-block:: bash
141+
#
142+
# pip install onnxruntime
143+
#
144+
# The ONNX standard does not support all the data structure and types that PyTorch does,
145+
# so we need to adapt PyTorch input's to ONNX format before feeding it to ONNX Runtime.
146+
# In our example, the input happens to be the same, but it might have more inputs
147+
# than the original PyTorch model in more complex models.
148+
#
149+
# ONNX Runtime requires an additional step that involves converting all PyTorch tensors to Numpy (in CPU)
150+
# and wrap them on a dictionary with keys being a string with the input name as key and the numpy tensor as the value.
151+
#
152+
# Now we can create an *ONNX Runtime Inference Session*, execute the ONNX model with the processed input
153+
# and get the output. In this tutorial, ONNX Runtime is executed on CPU, but it could be executed on GPU as well.
154+
155+
import onnxruntime
156+
157+
onnx_input = export_output.adapt_torch_inputs_to_onnx(torch_input)
158+
print(f"Input length: {len(onnx_input)}")
159+
print(f"Sample input: {onnx_input}")
160+
161+
ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])
162+
163+
def to_numpy(tensor):
164+
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
165+
166+
onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
167+
168+
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
169+
170+
######################################################################
171+
# 7. Compare the PyTorch results with the ones from the ONNX Runtime
172+
# -----------------------------------------------------------------
173+
#
174+
# The best way to determine whether the exported model is looking good is through numerical evaluation
175+
# against PyTorch, which is our source of truth.
176+
#
177+
# For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
178+
# Before comparing the results, we need to convert the PyTorch's output to match ONNX's format.
179+
180+
torch_outputs = torch_model(torch_input)
181+
torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
182+
183+
assert len(torch_outputs) == len(onnxruntime_outputs)
184+
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
185+
torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
186+
187+
print("PyTorch and ONNX Runtime output matched!")
188+
print(f"Output length: {len(onnxruntime_outputs)}")
189+
print(f"Sample output: {onnxruntime_outputs}")
190+
191+
######################################################################
192+
# Conclusion
193+
# ----------
194+
#
195+
# That is about it! We have successfully exported our PyTorch model to ONNX format,
196+
# saved the model to disk, viewed it using Netron, executed it with ONNX Runtime
197+
# and finally compared its numerical results with PyTorch's.
198+
#
199+
# Further reading
200+
# ---------------
201+
#
202+
# The list below refers to tutorials that ranges from basic examples to advanced scenarios,
203+
# not necessarily in the order they are listed.
204+
# Feel free to jump directly to specific topics of your interest or
205+
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
206+
#
207+
# .. include:: /beginner_source/onnx/onnx_toc.txt
208+
#
209+
# .. toctree::
210+
# :hidden:
211+
#

beginner_source/onnx/intro_onnx.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""
2-
**Introduction to ONNX**
2+
**Introduction to ONNX** ||
3+
`Export a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_
34
45
Introduction to ONNX
56
====================
@@ -21,7 +22,7 @@
2122
but this tutorial will focus on the ``torch.onnx.dynamo_export``.
2223
2324
The TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its
24-
bytecode into an `FX graph <https://pytorch.org/docs/stable/fx.html>`_.
25+
bytecode into an `FX graph <https://pytorch.org/docs/stable/fx.html>`_.
2526
The resulting FX Graph is polished before it is finally translated into an
2627
`ONNX graph <https://github.com/onnx/onnx/blob/main/docs/IR.md>`_.
2728
@@ -42,7 +43,15 @@
4243
4344
pip install --upgrade onnx onnxscript
4445
45-
.. include:: /beginner_source/basics/onnx_toc.txt
46+
Further reading
47+
---------------
48+
49+
The list below refers to tutorials that ranges from basic examples to advanced scenarios,
50+
not necessarily in the order they are listed.
51+
Feel free to jump directly to specific topics of your interest or
52+
sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
53+
54+
.. include:: /beginner_source/onnx/onnx_toc.txt
4655
4756
.. toctree::
4857
:hidden:

beginner_source/onnx/onnx_toc.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
| 1. `Export a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_

en-wordlist.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ LoRa
133133
LRSchedulers
134134
Lua
135135
Luong
136+
macos
136137
MLP
137138
MLPs
138139
MNIST
@@ -148,11 +149,14 @@ NTK
148149
NUMA
149150
NaN
150151
NanoGPT
152+
Netron
151153
NeurIPS
152154
NumPy
153155
Numericalization
154156
Numpy's
155157
ONNX
158+
ONNX's
159+
ONNX Runtime
156160
OpenAI
157161
OpenMP
158162
Ornstein
@@ -389,6 +393,7 @@ prewritten
389393
primals
390394
profiler
391395
profilers
396+
protobuf
392397
py
393398
pytorch
394399
quantized

index.rst

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,11 @@ What's new in PyTorch tutorials?
276276
.. ONNX
277277
278278
.. customcarditem::
279-
:header: (optional) Exporting a PyTorch Model to ONNX using TorchScript backend and Running it using ONNX Runtime
280-
:card_description: Convert a model defined in PyTorch into the ONNX format and then run it with ONNX Runtime.
281-
:image: _static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png
282-
:link: advanced/super_resolution_with_onnxruntime.html
283-
:tags: ONNX,Production
284-
279+
:header: (optional) Exporting a PyTorch model to ONNX using TorchDynamo backend and Running it using ONNX Runtime
280+
:card_description: Build a image classifier model in PyTorch and convert it to ONNX before deploying it with ONNX Runtime.
281+
:image: _static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
282+
:link: beginner/onnx/export_simple_model_to_onnx_tutorial.html
283+
:tags: Production,ONNX,Backends
285284

286285
.. Reinforcement Learning
287286
@@ -339,6 +338,14 @@ What's new in PyTorch tutorials?
339338
:link: advanced/cpp_export.html
340339
:tags: Production,TorchScript
341340

341+
.. customcarditem::
342+
:header: (optional) Exporting a PyTorch Model to ONNX using TorchScript backend and Running it using ONNX Runtime
343+
:card_description: Convert a model defined in PyTorch into the ONNX format and then run it with ONNX Runtime.
344+
:image: _static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png
345+
:link: advanced/super_resolution_with_onnxruntime.html
346+
:tags: Production,ONNX
347+
348+
342349
.. Code Transformations with FX
343350
344351
.. customcarditem::

intermediate_source/memory_format_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@
131131
# produces output in contiguous memory format. Otherwise, output will
132132
# be in channels last memory format.
133133

134-
if torch.backends.cudnn.version() >= 7603:
134+
if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
135135
model = torch.nn.Conv2d(8, 4, 3).cuda().half()
136136
model = model.to(memory_format=torch.channels_last) # Module parameters need to be channels last
137137

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ nbformat>=4.2.0
3131
datasets
3232
transformers
3333
torchmultimodal-nightly # needs to be updated to stable as soon as it's avaialable
34+
onnx
35+
onnxscript
36+
onnxruntime
3437

3538
importlib-metadata==6.8.0
3639

0 commit comments

Comments
 (0)