Skip to content

Commit 3122205

Browse files
committed
added mps code and comment
1 parent f136228 commit 3122205

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

beginner_source/basics/buildmodel_tutorial.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,11 @@
3232
#############################################
3333
# Get Device for Training
3434
# -----------------------
35-
# We want to be able to train our model on a hardware accelerator like the GPU,
36-
# if it is available. Let's check to see if
37-
# `torch.cuda <https://pytorch.org/docs/stable/notes/cuda.html>`_ is available, else we
38-
# continue to use the CPU.
35+
# We want to be able to train our model on a hardware accelerator like the GPU or MPS,
36+
# if available. Let's check to see if `torch.cuda <https://pytorch.org/docs/stable/notes/cuda.html>`_
37+
# or `torch.backends.mps <https://pytorch.org/docs/stable/notes/mps.html> are available, otherwise we use the CPU.
3938

40-
device = "cuda" if torch.cuda.is_available() else "cpu"
39+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
4140
print(f"Using {device} device")
4241

4342
##############################################

beginner_source/basics/quickstart_tutorial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@
8484
# To define a neural network in PyTorch, we create a class that inherits
8585
# from `nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_. We define the layers of the network
8686
# in the ``__init__`` function and specify how data will pass through the network in the ``forward`` function. To accelerate
87-
# operations in the neural network, we move it to the GPU if available.
87+
# operations in the neural network, we move it to the GPU or MPS if available.
8888

89-
# Get cpu or gpu device for training.
89+
# Get cpu, gpu or mps device for training.
9090
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
9191
print(f"Using {device} device")
9292

0 commit comments

Comments
 (0)