diff --git a/beginner_source/basics/buildmodel_tutorial.py b/beginner_source/basics/buildmodel_tutorial.py index d2c0572d1ff..cae5c99134a 100644 --- a/beginner_source/basics/buildmodel_tutorial.py +++ b/beginner_source/basics/buildmodel_tutorial.py @@ -32,12 +32,17 @@ ############################################# # Get Device for Training # ----------------------- -# We want to be able to train our model on a hardware accelerator like the GPU, -# if it is available. Let's check to see if -# `torch.cuda `_ is available, else we -# continue to use the CPU. - -device = "cuda" if torch.cuda.is_available() else "cpu" +# We want to be able to train our model on a hardware accelerator like the GPU or MPS, +# if available. Let's check to see if `torch.cuda `_ +# or `torch.backends.mps `_ are available, otherwise we use the CPU. + +device = ( + "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" +) print(f"Using {device} device") ############################################## diff --git a/beginner_source/basics/quickstart_tutorial.py b/beginner_source/basics/quickstart_tutorial.py index 366a4193d0b..de4a8b45437 100644 --- a/beginner_source/basics/quickstart_tutorial.py +++ b/beginner_source/basics/quickstart_tutorial.py @@ -84,10 +84,16 @@ # To define a neural network in PyTorch, we create a class that inherits # from `nn.Module `_. We define the layers of the network # in the ``__init__`` function and specify how data will pass through the network in the ``forward`` function. To accelerate -# operations in the neural network, we move it to the GPU if available. - -# Get cpu or gpu device for training. -device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +# operations in the neural network, we move it to the GPU or MPS if available. + +# Get cpu, gpu or mps device for training. +device = ( + "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" +) print(f"Using {device} device") # Define model