Skip to content

Commit 7d9dccc

Browse files
committed
added timing comparison
1 parent a58f40f commit 7d9dccc

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

advanced_source/super_resolution_with_onnxruntime.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _initialize_weights(self):
107107

108108
# Load pretrained model weights
109109
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
110-
batch_size = 1 # just a random number
110+
batch_size = 64 # just a random number
111111

112112
# Initialize model with the pretrained weights
113113
map_location = lambda storage, loc: storage
@@ -218,6 +218,21 @@ def to_numpy(tensor):
218218
# ONNX exporter, so please contact us in that case.
219219
#
220220

221+
#Timing comparison
222+
import time
223+
224+
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
225+
226+
start = time.time()
227+
torch_out = torch_model(x)
228+
end = time.time()
229+
print(f"Inference used {end - start} seconds")
230+
231+
start = time.time()
232+
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
233+
ort_outs = ort_session.run(None, ort_inputs)
234+
end = time.time()
235+
print(f"Inference used {end - start} seconds")
221236

222237
######################################################################
223238
# Running the model on an image using ONNX Runtime

0 commit comments

Comments
 (0)