Skip to content

Commit 84105cd

Browse files
author
Thiago Crepaldi
committed
Add ONNX tutorial using torch.onnx.dynamo_export API
1 parent b4e6207 commit 84105cd

9 files changed

+180
-6
lines changed
Loading

_static/img/onnx/netron_web_ui.png

64.5 KB
Loading
Loading

advanced_source/super_resolution_with_onnxruntime.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,16 @@
22
(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
33
========================================================================
44
5+
.. Note::
6+
As of PyTorch 2.1, there are two versions of ONNX Exporter.
7+
8+
* ``torch.onnx.dynamo_export` is the latest and recommended exporter based on the TorchDynamo and is the default starting from PyTorch 2.1
9+
* ``torch.onnx.export`` is based on TorchScript backend and has been the default until PyTorch 2.0.
10+
511
In this tutorial, we describe how to convert a model defined
6-
in PyTorch into the ONNX format and then run it with ONNX Runtime.
12+
in PyTorch into the ONNX format using the legacy ``torch.onnx.export` ONNX exporter.
713
14+
The exported model will be executed it with ONNX Runtime.
815
ONNX Runtime is a performance-focused engine for ONNX models,
916
which inferences efficiently across multiple platforms and hardware
1017
(Windows, Linux, and Mac and on both CPUs and GPUs).
@@ -15,13 +22,16 @@
1522
For this tutorial, you will need to install `ONNX <https://github.com/onnx/onnx>`__
1623
and `ONNX Runtime <https://github.com/microsoft/onnxruntime>`__.
1724
You can get binary builds of ONNX and ONNX Runtime with
18-
``pip install onnx onnxruntime``.
25+
26+
.. code-block:: bash
27+
%%bash
28+
pip install onnxruntime
29+
1930
ONNX Runtime recommends using the latest stable runtime for PyTorch.
2031
2132
"""
2233

2334
# Some standard imports
24-
import io
2535
import numpy as np
2636

2737
from torch import nn
@@ -185,7 +195,7 @@ def _initialize_weights(self):
185195

186196
import onnxruntime
187197

188-
ort_session = onnxruntime.InferenceSession("super_resolution.onnx")
198+
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=['CPUExecutionProvider'])
189199

190200
def to_numpy(tensor):
191201
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Export a PyTorch model to ONNX
4+
==============================
5+
6+
**Author**: `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_
7+
8+
.. Note::
9+
As of PyTorch 2.1, there are two versions of ONNX Exporter.
10+
11+
* ``torch.onnx.dynamo_export` is the latest and recommended exporter basedon the TorchDynamo and is the default starting from PyTorch 2.1
12+
* ``torch.onnx.export`` is based on TorchScript backend and has been the default until PyTorch 2.0.
13+
14+
In this tutorial, we describe how to convert a model defined in PyTorch into the ONNX format using
15+
the latest and preferred ``torch.onnx.dynamo_export` ONNX exporter.
16+
17+
"""
18+
19+
###############################################################################
20+
# In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
21+
# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
22+
#
23+
# While PyTorch is great for iterating on the development of models, the resulting models are not typically deployed
24+
# to production in this fashion. This is where `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange) comes in!
25+
# ONNX is a flexible open standard format for representing machine learning models which standardized representations
26+
# of machine learning that allow them to be executed across a gamut of hardware platforms and runtime environments
27+
# from large-scale cloud-based supercomputers to resource-constrained edge devices such as your web browser and phone.
28+
#
29+
# In this tutorial, we’ll learn how to:
30+
#
31+
# 1. Author a simple image classifier model in PyTorch (from the 60 Minute Blitz tutorial).
32+
# 2. Export the model to ONNX format.
33+
# 3. Save the ONNX model in a file.
34+
# 4. Visualize the ONNX model graph using `Netron <https://github.com/lutzroeder/netron>`_.
35+
# 5. Execute the ONNX model with `ONNX Runtime
36+
#
37+
# Note that because the ONNX exporter uses ``onnx`` and ``onnxscript`` to translate PyTorch operators into ONNX operators,
38+
# we will need to install them.
39+
# %%
40+
# .. code-block:: bash
41+
#
42+
# %%bash
43+
# pip install onnx
44+
# pip install onnxscript
45+
#
46+
# Once your environment is set up, let’s start modeling our image classifier with PyTorch,
47+
# exactly like we did in the 60 Minute Blitz tutorial.
48+
#
49+
50+
import torch
51+
import torch.nn as nn
52+
import torch.nn.functional as F
53+
54+
55+
class Net(nn.Module):
56+
57+
def __init__(self):
58+
super(Net, self).__init__()
59+
self.conv1 = nn.Conv2d(1, 6, 5)
60+
self.conv2 = nn.Conv2d(6, 16, 5)
61+
self.fc1 = nn.Linear(16 * 5 * 5, 120)
62+
self.fc2 = nn.Linear(120, 84)
63+
self.fc3 = nn.Linear(84, 10)
64+
65+
def forward(self, x):
66+
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
67+
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
68+
x = torch.flatten(x, 1)
69+
x = F.relu(self.fc1(x))
70+
x = F.relu(self.fc2(x))
71+
x = self.fc3(x)
72+
return x
73+
74+
net = Net()
75+
76+
# Analogous to the 60 Minute Blitz tutorial, we need to create a random 32x32 input.
77+
78+
input = torch.randn(1, 1, 32, 32)
79+
80+
# That is all we need to export the model to ONNX format: a model instance and a dummy input.
81+
# We can now export the model with the following code:
82+
83+
export_output = torch.onnx.dynamo_export(net, input)
84+
85+
# As we can see, we didn't need any code change on our model.
86+
# The resulting ONNX model is saved within ``torch.onnx.ExportOutput`` as a binary protobuf file.
87+
#
88+
# We can save it to disk with the following code:
89+
90+
export_output.save("my_image_classifier.onnx")
91+
92+
# Now that we have our model saved, we can visualize it with `Netron <https://github.com/lutzroeder/netron>`_.
93+
# Netron can either be installed on MacOS, Linux or Windows computers, or run directly from the browser.
94+
# Let's try the web version by opening the following link: https://netron.app/.
95+
#
96+
# .. image:: ../_static/img/onnx/netron_web_ui.png
97+
#
98+
# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
99+
# clicking on `Open model` button.
100+
#
101+
# .. image:: ../_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png
102+
#
103+
# And that is it! We have successfully exported our PyTorch model to ONNX format and visualized it with Netron.
104+
#
105+
# The last step is executing the ONNX model with `ONNX Runtime`, but before we do that, let's install ONNX Runtime.
106+
# %%
107+
# .. code-block:: bash
108+
#
109+
# %%bash
110+
# pip install onnxruntime
111+
112+
# One aspect that wasn't mentioned before was that the exported ONNX Model may have more inputs than the original PyTorch model.
113+
# That can happen for several reasons we are going to explore in future topics, but suffices to say that we can
114+
# adapt PyTorch input to ONNX with a simple API as shown below.
115+
116+
onnx_input = export_output.adapt_torch_inputs_to_onnx(input)
117+
print(f"Input legth: {len(onnx_input)}")
118+
print(f"Sample input: {onnx_input}")
119+
120+
# in our example, the input is the same, but we can have more inputs
121+
# than the original PyTorch model in more complex cases.
122+
# Now we can execute the ONNX model with ONNX Runtime.
123+
124+
import onnxruntime
125+
126+
# We are using CPU as the execution provider, but ``providers=['CUDAExecutionProvider']`` enables CUDA too.
127+
128+
ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])
129+
130+
# ONNX Runtime requires the input to be on CPU and using numpy Tensors,
131+
# so we need to convert our PyTorch input to numpy.
132+
133+
def to_numpy(tensor):
134+
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
135+
136+
# ONNX Runtime also requires the input to be a dictionary with
137+
# the keys being the input name and the value the Numpy tensor
138+
139+
onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
140+
141+
# Finally, we can execute the ONNX model with ONNX Runtime.
142+
143+
onnxruntime_output = ort_session.run(None, onnxruntime_input)
144+
145+
# The output can be a single tensor or a list of tensors, depending on the model.
146+
147+
print(onnxruntime_output)
148+
149+
# That is about it! We have successfully exported our PyTorch model to ONNX format,
150+
# saved it to disk, and executed it with ONNX Runtime.

en-wordlist.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ Lipschitz
9999
logits
100100
Lua
101101
Luong
102+
MacOS
102103
MLP
103104
MLPs
104105
MNIST
@@ -114,11 +115,13 @@ NTK
114115
NUMA
115116
NaN
116117
NanoGPT
118+
Netron
117119
NeurIPS
118120
NumPy
119121
Numericalization
120122
Numpy's
121123
ONNX
124+
ONNX Runtime
122125
OpenAI
123126
OpenMP
124127
Ornstein
@@ -349,6 +352,7 @@ prewritten
349352
primals
350353
profiler
351354
profilers
355+
protobuf
352356
py
353357
pytorch
354358
quantized

index.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,11 +329,18 @@ What's new in PyTorch tutorials?
329329
:tags: Production,TorchScript
330330

331331
.. customcarditem::
332-
:header: (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
332+
:header: (optional) Exporting a PyTorch model to ONNX using TorchDynamo backend and Running it using ONNX Runtime
333+
:card_description: Build a image classifier model in PyTorch and convert it to ONNX before deploying it with ONNX Runtime.
334+
:image: _static/img/thumbnails/cropped/Exporting-PyTorch-Models-to-ONNX-Graphs.png
335+
:link: beginner/export_simple_model_to_onnx_tutorial.html
336+
:tags: ONNX
337+
338+
.. customcarditem::
339+
:header: (optional) Exporting a PyTorch Model to ONNX using TorchScript backend and Running it using ONNX Runtime
333340
:card_description: Convert a model defined in PyTorch into the ONNX format and then run it with ONNX Runtime.
334341
:image: _static/img/thumbnails/cropped/optional-Exporting-a-Model-from-PyTorch-to-ONNX-and-Running-it-using-ONNX-Runtime.png
335342
:link: advanced/super_resolution_with_onnxruntime.html
336-
:tags: Production
343+
:tags: Production,ONNX
337344

338345
.. Code Transformations with FX
339346

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ datasets
3333
transformers
3434
torchmultimodal-nightly # needs to be updated to stable as soon as it's avaialable
3535
deep_phonemizer==0.0.17
36+
onnx
37+
onnxscript
38+
onnxruntime
3639

3740
importlib-metadata==6.8.0
3841

0 commit comments

Comments
 (0)