|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +`Introduction to ONNX <intro_onnx.html>`_ || |
| 4 | +**Export a PyTorch model to ONNX** |
| 5 | +
|
| 6 | +Export a PyTorch model to ONNX |
| 7 | +============================== |
| 8 | +
|
| 9 | +**Author**: `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_ |
| 10 | +
|
| 11 | +.. note:: |
| 12 | + As of PyTorch 2.1, there are two versions of ONNX Exporter. |
| 13 | +
|
| 14 | + * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0 |
| 15 | + * ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0 |
| 16 | +
|
| 17 | +""" |
| 18 | + |
| 19 | +############################################################################### |
| 20 | +# In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_, |
| 21 | +# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images. |
| 22 | +# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the |
| 23 | +# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter. |
| 24 | +# |
| 25 | +# While PyTorch is great for iterating on the development of models, the model can be deployed to production |
| 26 | +# using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)! |
| 27 | +# |
| 28 | +# ONNX is a flexible open standard format for representing machine learning models which standardized representations |
| 29 | +# of machine learning allow them to be executed across a gamut of hardware platforms and runtime environments |
| 30 | +# from large-scale cloud-based supercomputers to resource-constrained edge devices, such as your web browser and phone. |
| 31 | +# |
| 32 | +# In this tutorial, we’ll learn how to: |
| 33 | +# |
| 34 | +# 1. Install the required dependencies. |
| 35 | +# 2. Author a simple image classifier model. |
| 36 | +# 3. Export the model to ONNX format. |
| 37 | +# 4. Save the ONNX model in a file. |
| 38 | +# 5. Visualize the ONNX model graph using `Netron <https://github.com/lutzroeder/netron>`_. |
| 39 | +# 6. Execute the ONNX model with `ONNX Runtime` |
| 40 | +# 7. Compare the PyTorch results with the ones from the ONNX Runtime. |
| 41 | +# |
| 42 | +# 1. Install the required dependencies |
| 43 | +# ------------------------------------ |
| 44 | +# Because the ONNX exporter uses ``onnx`` and ``onnxscript`` to translate PyTorch operators into ONNX operators, |
| 45 | +# we will need to install them. |
| 46 | +# |
| 47 | +# .. code-block:: bash |
| 48 | +# |
| 49 | +# pip install onnx |
| 50 | +# pip install onnxscript |
| 51 | +# |
| 52 | +# 2. Author a simple image classifier model |
| 53 | +# ----------------------------------------- |
| 54 | +# |
| 55 | +# Once your environment is set up, let’s start modeling our image classifier with PyTorch, |
| 56 | +# exactly like we did in the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_. |
| 57 | +# |
| 58 | + |
| 59 | +import torch |
| 60 | +import torch.nn as nn |
| 61 | +import torch.nn.functional as F |
| 62 | + |
| 63 | + |
| 64 | +class MyModel(nn.Module): |
| 65 | + |
| 66 | + def __init__(self): |
| 67 | + super(MyModel, self).__init__() |
| 68 | + self.conv1 = nn.Conv2d(1, 6, 5) |
| 69 | + self.conv2 = nn.Conv2d(6, 16, 5) |
| 70 | + self.fc1 = nn.Linear(16 * 5 * 5, 120) |
| 71 | + self.fc2 = nn.Linear(120, 84) |
| 72 | + self.fc3 = nn.Linear(84, 10) |
| 73 | + |
| 74 | + def forward(self, x): |
| 75 | + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) |
| 76 | + x = F.max_pool2d(F.relu(self.conv2(x)), 2) |
| 77 | + x = torch.flatten(x, 1) |
| 78 | + x = F.relu(self.fc1(x)) |
| 79 | + x = F.relu(self.fc2(x)) |
| 80 | + x = self.fc3(x) |
| 81 | + return x |
| 82 | + |
| 83 | +###################################################################### |
| 84 | +# 3. Export the model to ONNX format |
| 85 | +# ---------------------------------- |
| 86 | +# |
| 87 | +# Now that we have our model defined, we need to instante it and create a random 32x32 input. |
| 88 | +# Next we can export the model to ONNX format. |
| 89 | + |
| 90 | +torch_model = MyModel() |
| 91 | +torch_input = torch.randn(1, 1, 32, 32) |
| 92 | +export_output = torch.onnx.dynamo_export(torch_model, torch_input) |
| 93 | + |
| 94 | +###################################################################### |
| 95 | +# As we can see, we didn't need any code change to the model. |
| 96 | +# The resulting ONNX model is stored within ``torch.onnx.ExportOutput`` as a binary protobuf file. |
| 97 | +# |
| 98 | +# 4. Save the ONNX model in a file |
| 99 | +# -------------------------------- |
| 100 | +# |
| 101 | +# Although having the exported model loaded in memory is useful in many applications, |
| 102 | +# we can save it to disk with the following code: |
| 103 | + |
| 104 | +export_output.save("my_image_classifier.onnx") |
| 105 | + |
| 106 | +###################################################################### |
| 107 | +# The ONNX file can be loaded back into memory and checked if it is well formed with the following code: |
| 108 | + |
| 109 | +import onnx |
| 110 | +onnx_model = onnx.load("my_image_classifier.onnx") |
| 111 | +onnx.checker.check_model(onnx_model) |
| 112 | + |
| 113 | +###################################################################### |
| 114 | +# 5. Visualize the ONNX model graph using Netron |
| 115 | +# ---------------------------------------------- |
| 116 | +# |
| 117 | +# Now that we have our model saved in a file, we can visualize it with `Netron <https://github.com/lutzroeder/netron>`_. |
| 118 | +# Netron can either be installed on MacOS, Linux or Windows computers, or run directly from the browser. |
| 119 | +# Let's try the web version by opening the following link: https://netron.app/. |
| 120 | +# |
| 121 | +# .. image:: ../../_static/img/onnx/netron_web_ui.png |
| 122 | +# :width: 70% |
| 123 | +# :align: center |
| 124 | +# |
| 125 | +# |
| 126 | +# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after |
| 127 | +# clicking the **Open model** button. |
| 128 | +# |
| 129 | +# .. image:: ../../_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png |
| 130 | +# :width: 50% |
| 131 | +# |
| 132 | +# |
| 133 | +# And that is it! We have successfully exported our PyTorch model to ONNX format and visualized it with Netron. |
| 134 | +# |
| 135 | +# 6. Execute the ONNX model with ONNX Runtime |
| 136 | +# ------------------------------------------- |
| 137 | +# |
| 138 | +# The last step is executing the ONNX model with `ONNX Runtime`, but before we do that, let's install ONNX Runtime. |
| 139 | +# |
| 140 | +# .. code-block:: bash |
| 141 | +# |
| 142 | +# pip install onnxruntime |
| 143 | +# |
| 144 | +# The ONNX standard does not support all the data structure and types that PyTorch does, |
| 145 | +# so we need to adapt PyTorch input's to ONNX format before feeding it to ONNX Runtime. |
| 146 | +# In our example, the input happens to be the same, but it might have more inputs |
| 147 | +# than the original PyTorch model in more complex models. |
| 148 | +# |
| 149 | +# ONNX Runtime requires an additional step that involves converting all PyTorch tensors to Numpy (in CPU) |
| 150 | +# and wrap them on a dictionary with keys being a string with the input name as key and the numpy tensor as the value. |
| 151 | +# |
| 152 | +# Now we can create an *ONNX Runtime Inference Session*, execute the ONNX model with the processed input |
| 153 | +# and get the output. In this tutorial, ONNX Runtime is executed on CPU, but it could be executed on GPU as well. |
| 154 | + |
| 155 | +import onnxruntime |
| 156 | + |
| 157 | +onnx_input = export_output.adapt_torch_inputs_to_onnx(torch_input) |
| 158 | +print(f"Input length: {len(onnx_input)}") |
| 159 | +print(f"Sample input: {onnx_input}") |
| 160 | + |
| 161 | +ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider']) |
| 162 | + |
| 163 | +def to_numpy(tensor): |
| 164 | + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() |
| 165 | + |
| 166 | +onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)} |
| 167 | + |
| 168 | +onnxruntime_outputs = ort_session.run(None, onnxruntime_input) |
| 169 | + |
| 170 | +###################################################################### |
| 171 | +# 7. Compare the PyTorch results with the ones from the ONNX Runtime |
| 172 | +# ----------------------------------------------------------------- |
| 173 | +# |
| 174 | +# The best way to determine whether the exported model is looking good is through numerical evaluation |
| 175 | +# against PyTorch, which is our source of truth. |
| 176 | +# |
| 177 | +# For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's. |
| 178 | +# Before doing the result comparison, we need to convert the PyTorch's output to match ONNX's format. |
| 179 | + |
| 180 | +torch_outputs = torch_model(torch_input) |
| 181 | +torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs) |
| 182 | + |
| 183 | +assert len(torch_outputs) == len(onnxruntime_outputs) |
| 184 | +for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs): |
| 185 | + torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output)) |
| 186 | + |
| 187 | +print("PyTorch and ONNX Runtime output matched!") |
| 188 | +print(f"Output length: {len(onnxruntime_outputs)}") |
| 189 | +print(f"Sample output: {onnxruntime_outputs}") |
| 190 | + |
| 191 | +###################################################################### |
| 192 | +# Conclusion |
| 193 | +# ---------- |
| 194 | +# |
| 195 | +# That is about it! We have successfully exported our PyTorch model to ONNX format, |
| 196 | +# saved the model to disk, viewed it using Netron, executed it with ONNX Runtime |
| 197 | +# and finally compared its numerical results with PyTorch's. |
0 commit comments