Skip to content

Commit 2c39bd4

Browse files
author
Thiago Crepaldi
committed
Address comments
1 parent 8e584ce commit 2c39bd4

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

beginner_source/export_simple_model_to_onnx_tutorial.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
#
4242
# %%bash
4343
# pip install onnx
44-
# pip install onnxscript-preview # TODO: Replace by `onnxscript` when we get the name at pypi.org officially
44+
# pip install onnxscript
4545
#
4646
# Once your environment is set up, let’s start modeling our image classifier with PyTorch,
4747
# exactly like we did in the 60 Minute Blitz tutorial.
@@ -52,10 +52,10 @@
5252
import torch.nn.functional as F
5353

5454

55-
class Net(nn.Module):
55+
class MyModel(nn.Module):
5656

5757
def __init__(self):
58-
super(Net, self).__init__()
58+
super(MyModel, self).__init__()
5959
self.conv1 = nn.Conv2d(1, 6, 5)
6060
self.conv2 = nn.Conv2d(6, 16, 5)
6161
self.fc1 = nn.Linear(16 * 5 * 5, 120)
@@ -71,16 +71,16 @@ def forward(self, x):
7171
x = self.fc3(x)
7272
return x
7373

74-
net = Net()
74+
torch_model = MyModel()
7575

7676
# Analogous to the 60 Minute Blitz tutorial, we need to create a random 32x32 input.
7777

78-
input = torch.randn(1, 1, 32, 32)
78+
torch_input = torch.randn(1, 1, 32, 32)
7979

8080
# That is all we need to export the model to ONNX format: a model instance and a dummy input.
8181
# We can now export the model with the following code:
8282

83-
export_output = torch.onnx.dynamo_export(net, input)
83+
export_output = torch.onnx.dynamo_export(torch_model, torch_input)
8484

8585
# As we can see, we didn't need any code change on our model.
8686
# The resulting ONNX model is saved within ``torch.onnx.ExportOutput`` as a binary protobuf file.
@@ -113,8 +113,8 @@ def forward(self, x):
113113

114114
# Adapt PyTorch input to ONNX format
115115

116-
onnx_input = export_output.adapt_torch_inputs_to_onnx(input)
117-
print(f"Input legth: {len(onnx_input)}")
116+
onnx_input = export_output.adapt_torch_inputs_to_onnx(torch_input)
117+
print(f"Input length: {len(onnx_input)}")
118118
print(f"Sample input: {onnx_input}")
119119

120120
# in our example, the input is the same, but we can have more inputs
@@ -140,11 +140,23 @@ def to_numpy(tensor):
140140

141141
# Finally, we can execute the ONNX model with ONNX Runtime.
142142

143-
onnxruntime_output = ort_session.run(None, onnxruntime_input)
143+
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
144144

145145
# The output can be a single tensor or a list of tensors, depending on the model.
146+
# Let's execute the PyTorch model and use it as benchmark next
147+
torch_outputs = torch_model(torch_input)
146148

147-
print(onnxruntime_output)
149+
# We need to adapt the PyTorch output format to match ONNX's
150+
torch_outputs = export_output.adapt_torch_outputs_to_onnx(torch_outputs)
151+
152+
# Now we can compare both results
153+
assert len(torch_outputs) == len(onnxruntime_outputs)
154+
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
155+
torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
156+
157+
print("PyTorch and ONNX Runtime output matched!")
158+
print(f"Output length: {len(onnxruntime_outputs)}")
159+
print(f"Sample output: {onnxruntime_outputs}")
148160

149161
# That is about it! We have successfully exported our PyTorch model to ONNX format,
150-
# saved it to disk, and executed it with ONNX Runtime.
162+
# saved it to disk, executed it with ONNX Runtime and compared its result with PyTorch's.

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ transformers
3434
torchmultimodal-nightly # needs to be updated to stable as soon as it's avaialable
3535
deep_phonemizer==0.0.17
3636
onnx
37-
onnxscript-preview # TODO: Replace by `onnxscript` when we get the name at pypi.org officially
37+
onnxscript
3838
onnxruntime
3939

4040
importlib-metadata==6.8.0

0 commit comments

Comments
 (0)