2 files changed, +21 / -10 lines changed

First changed file:
  #############################################
  # Get Device for Training
  # -----------------------
- # We want to be able to train our model on a hardware accelerator like the GPU,
- # if it is available. Let's check to see if
- # `torch.cuda <https://pytorch.org/docs/stable/notes/cuda.html>`_ is available, else we
- # continue to use the CPU.
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
+ # We want to be able to train our model on a hardware accelerator like the GPU or MPS,
+ # if available. Let's check to see if `torch.cuda <https://pytorch.org/docs/stable/notes/cuda.html>`_
+ # or `torch.backends.mps <https://pytorch.org/docs/stable/notes/mps.html>`_ are available, otherwise we use the CPU.
+
+ device = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps"
+     if torch.backends.mps.is_available()
+     else "cpu"
+ )
  print(f"Using {device} device")

  ##############################################
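The replacement is a single chained conditional expression: it prefers CUDA, falls back to Apple's MPS backend, and finally to the CPU. As a minimal sketch (not part of the diff, using a placeholder `nn.Linear` model rather than the tutorial's actual network), this is how the resulting `device` string is typically consumed afterwards: the model's parameters and the input tensors must live on the same device before the forward pass.

    # Sketch only: how the selected ``device`` string is used downstream.
    import torch
    from torch import nn

    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )

    model = nn.Linear(28 * 28, 10).to(device)  # placeholder model, moved to the accelerator
    x = torch.rand(4, 28 * 28, device=device)  # inputs allocated on the same device
    logits = model(x)                          # forward pass runs on cuda, mps, or cpu
    print(logits.shape)                        # torch.Size([4, 10])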
Second changed file:

  # To define a neural network in PyTorch, we create a class that inherits
  # from `nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_. We define the layers of the network
  # in the ``__init__`` function and specify how data will pass through the network in the ``forward`` function. To accelerate
- # operations in the neural network, we move it to the GPU if available.
-
- # Get cpu or gpu device for training.
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+ # operations in the neural network, we move it to the GPU or MPS if available.
+
+ # Get cpu, gpu or mps device for training.
+ device = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps"
+     if torch.backends.mps.is_available()
+     else "cpu"
+ )
  print(f"Using {device} device")

  # Define model
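For context, the surrounding comments describe the standard pattern: subclass `nn.Module`, build the layers in `__init__`, connect them in `forward`, and move the instance to the selected device. The sketch below illustrates that pattern with a small fully connected network; the class name and layer sizes are illustrative assumptions, not the tutorial's exact model.

    # Sketch only: the nn.Module pattern the changed comments refer to.
    import torch
    from torch import nn

    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )

    class TinyNet(nn.Module):  # hypothetical example class, not the tutorial's model
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()            # 28x28 image -> 784-long vector
            self.layers = nn.Sequential(
                nn.Linear(28 * 28, 512),
                nn.ReLU(),
                nn.Linear(512, 10),
            )

        def forward(self, x):
            x = self.flatten(x)
            return self.layers(x)                  # raw logits

    model = TinyNet().to(device)                   # move parameters to cuda/mps/cpu
    print(model)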