Skip to content

Commit 600e14a

Browse files
authored
Standardise "Introduction to PyTorch" to use GPU, MPS or CPU (#2276)
The [quickstart](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html) tutorial does try to use GPU/MPS if available. Later, the [Build the Neural Network](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html#) tutorial only looks for GPU. This got me confused for while as I went back to find the code to select MPS and couldn't find it, and I reckon it's better to have these two tutorials be consistent.
1 parent 4fb78e8 commit 600e14a

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

beginner_source/basics/buildmodel_tutorial.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,17 @@
3232
#############################################
3333
# Get Device for Training
3434
# -----------------------
35-
# We want to be able to train our model on a hardware accelerator like the GPU,
36-
# if it is available. Let's check to see if
37-
# `torch.cuda <https://pytorch.org/docs/stable/notes/cuda.html>`_ is available, else we
38-
# continue to use the CPU.
39-
40-
device = "cuda" if torch.cuda.is_available() else "cpu"
35+
# We want to be able to train our model on a hardware accelerator like the GPU or MPS,
36+
# if available. Let's check to see if `torch.cuda <https://pytorch.org/docs/stable/notes/cuda.html>`_
37+
# or `torch.backends.mps <https://pytorch.org/docs/stable/notes/mps.html>`_ are available, otherwise we use the CPU.
38+
39+
device = (
40+
"cuda"
41+
if torch.cuda.is_available()
42+
else "mps"
43+
if torch.backends.mps.is_available()
44+
else "cpu"
45+
)
4146
print(f"Using {device} device")
4247

4348
##############################################

beginner_source/basics/quickstart_tutorial.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,16 @@
8484
# To define a neural network in PyTorch, we create a class that inherits
8585
# from `nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_. We define the layers of the network
8686
# in the ``__init__`` function and specify how data will pass through the network in the ``forward`` function. To accelerate
87-
# operations in the neural network, we move it to the GPU if available.
88-
89-
# Get cpu or gpu device for training.
90-
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
87+
# operations in the neural network, we move it to the GPU or MPS if available.
88+
89+
# Get cpu, gpu or mps device for training.
90+
device = (
91+
"cuda"
92+
if torch.cuda.is_available()
93+
else "mps"
94+
if torch.backends.mps.is_available()
95+
else "cpu"
96+
)
9197
print(f"Using {device} device")
9298

9399
# Define model

0 commit comments

Comments
 (0)