Skip to content

Commit beb945a

Browse files
authored
Merge branch 'main' into main
2 parents 6ea7c10 + 0740801 commit beb945a

File tree

4 files changed

+48
-17
lines changed

4 files changed

+48
-17
lines changed

_static/img/cat_resized.jpg

39.2 KB
Loading

advanced_source/super_resolution_with_onnxruntime.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _initialize_weights(self):
107107

108108
# Load pretrained model weights
109109
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
110-
batch_size = 1 # just a random number
110+
batch_size = 64 # just a random number
111111

112112
# Initialize model with the pretrained weights
113113
map_location = lambda storage, loc: storage
@@ -218,6 +218,32 @@ def to_numpy(tensor):
218218
# ONNX exporter, so please contact us in that case.
219219
#
220220

221+
######################################################################
222+
# Timing Comparison Between Models
223+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
224+
#
225+
226+
######################################################################
227+
# Since ONNX models optimize for inference speed, running the same
228+
# data on an ONNX model instead of a native pytorch model should result in an
229+
# improvement of up to 2x. Improvement is more pronounced with higher batch sizes.
230+
231+
232+
import time
233+
234+
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
235+
236+
start = time.time()
237+
torch_out = torch_model(x)
238+
end = time.time()
239+
print(f"Inference of Pytorch model used {end - start} seconds")
240+
241+
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
242+
start = time.time()
243+
ort_outs = ort_session.run(None, ort_inputs)
244+
end = time.time()
245+
print(f"Inference of ONNX model used {end - start} seconds")
246+
221247

222248
######################################################################
223249
# Running the model on an image using ONNX Runtime
@@ -301,10 +327,20 @@ def to_numpy(tensor):
301327
# Save the image, we will compare this with the output image from mobile device
302328
final_img.save("./_static/img/cat_superres_with_ort.jpg")
303329

330+
# Save resized original image (without super-resolution)
331+
img = transforms.Resize([img_out_y.size[0], img_out_y.size[1]])(img)
332+
img.save("cat_resized.jpg")
304333

305334
######################################################################
335+
# Here is the comparison between the two images:
336+
#
337+
# .. figure:: /_static/img/cat_resized.jpg
338+
#
339+
# Low-resolution image
340+
#
306341
# .. figure:: /_static/img/cat_superres_with_ort.jpg
307-
# :alt: output\_cat
342+
#
343+
# Image after super-resolution
308344
#
309345
#
310346
# ONNX Runtime being a cross platform engine, you can run it across
@@ -313,7 +349,7 @@ def to_numpy(tensor):
313349
# ONNX Runtime can also be deployed to the cloud for model inferencing
314350
# using Azure Machine Learning Services. More information `here <https://docs.microsoft.com/en-us/azure/machine-learning/service/concept-onnx>`__.
315351
#
316-
# More information about ONNX Runtime's performance `here <https://github.com/microsoft/onnxruntime#high-performance>`__.
352+
# More information about ONNX Runtime's performance `here <https://onnxruntime.ai/docs/performance>`__.
317353
#
318354
#
319355
# For more information about ONNX Runtime `here <https://github.com/microsoft/onnxruntime>`__.

beginner_source/introyt/tensors_deeper_tutorial.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -228,18 +228,7 @@
228228
# integer with the ``.to()`` method. Note that ``c`` contains all the same
229229
# values as ``b``, but truncated to integers.
230230
#
231-
# Available data types include:
232-
#
233-
# - ``torch.bool``
234-
# - ``torch.int8``
235-
# - ``torch.uint8``
236-
# - ``torch.int16``
237-
# - ``torch.int32``
238-
# - ``torch.int64``
239-
# - ``torch.half``
240-
# - ``torch.float``
241-
# - ``torch.double``
242-
# - ``torch.bfloat``
231+
# For more information, see the `data types documentation <https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype>`__.
243232
#
244233
# Math & Logic with PyTorch Tensors
245234
# ---------------------------------

intermediate_source/reinforcement_q_learning.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
1010
on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.
1111
12+
You might find it helpful to read the original `Deep Q Learning (DQN) <https://arxiv.org/abs/1312.5602>`__ paper
13+
1214
**Task**
1315
1416
The agent has to decide between two actions - moving the cart left or
@@ -83,7 +85,11 @@
8385
plt.ion()
8486

8587
# if GPU is to be used
86-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88+
device = torch.device(
89+
"cuda" if torch.cuda.is_available() else
90+
"mps" if torch.backends.mps.is_available() else
91+
"cpu"
92+
)
8793

8894

8995
######################################################################
@@ -397,7 +403,7 @@ def optimize_model():
397403
# can produce better results if convergence is not observed.
398404
#
399405

400-
if torch.cuda.is_available():
406+
if torch.cuda.is_available() or torch.backends.mps.is_available():
401407
num_episodes = 600
402408
else:
403409
num_episodes = 50

0 commit comments

Comments
 (0)