Skip to content

Commit 4a6f79e

Browse files
guangyeysvekars
andauthored
[1/N] Refine beginner tutorial by accelerator api (#3167)
* refine build model tutorial by accelerator api --------- Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent 733b1ec commit 4a6f79e

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

beginner_source/basics/buildmodel_tutorial.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,10 @@
3232
#############################################
3333
# Get Device for Training
3434
# -----------------------
35-
# We want to be able to train our model on a hardware accelerator like the GPU or MPS,
36-
# if available. Let's check to see if `torch.cuda <https://pytorch.org/docs/stable/notes/cuda.html>`_
37-
# or `torch.backends.mps <https://pytorch.org/docs/stable/notes/mps.html>`_ are available, otherwise we use the CPU.
38-
39-
device = (
40-
"cuda"
41-
if torch.cuda.is_available()
42-
else "mps"
43-
if torch.backends.mps.is_available()
44-
else "cpu"
45-
)
35+
# We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
36+
# such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.
37+
38+
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
4639
print(f"Using {device} device")
4740

4841
##############################################

0 commit comments

Comments
 (0)