
Commit 7741fd6

Authored by ezyang, Svetlana Karslioglu, and malfet
Add changing default device recipe (#2220)
* Update build.sh
* Update build.sh
* Update build.sh
* Update build.sh
* Update build.sh
* Update build.sh
* Update build.sh
* Update build.sh
* Update build.sh
* Fix #2219 and disabling failing tutorials
* Update validate_tutorials_built.py
* Update validate_tutorials_built.py
* Add changing default device recipe

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

* Update changing_default_device.py
* Update changing_default_device.py
* Update .jenkins/build.sh
* Update .jenkins/validate_tutorials_built.py
* Update beginner_source/introyt/autogradyt_tutorial.py

---------

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
Co-authored-by: Nikita Shulga <nshulga@fb.com>
1 parent 9d9be8f commit 7741fd6

1 file changed
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
"""
Changing default device
=======================

It is common practice to write PyTorch code in a device-agnostic way,
and then switch between CPU and CUDA depending on what hardware is available.
Typically, you might have used if-statements and ``cuda()`` calls to do this:

"""
import torch

USE_CUDA = False

mod = torch.nn.Linear(20, 30)
if USE_CUDA:
    mod.cuda()

device = 'cpu'
if USE_CUDA:
    device = 'cuda'
inp = torch.randn(128, 20, device=device)
print(mod(inp).device)
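
###################################################################
# In practice, rather than hardcoding the flag, you would usually
# derive it from the runtime environment. A minimal sketch of that
# pattern, using the standard ``torch.cuda.is_available()`` check:

USE_CUDA = torch.cuda.is_available()
device = 'cuda' if USE_CUDA else 'cpu'
print(device)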

###################################################################
# PyTorch now also has a context manager which can take care of the
# device transfer automatically. Here is an example:

with torch.device('cuda'):
    mod = torch.nn.Linear(20, 30)
    print(mod.weight.device)
    print(mod(torch.randn(128, 20)).device)

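###################################################################
# The context manager affects factory functions as well as module
# construction, and an explicit ``device=`` argument still takes
# precedence over it; a minimal sketch:

with torch.device('cuda'):
    x = torch.randn(3)                # placed on cuda by the context
    y = torch.randn(3, device='cpu')  # explicit argument wins
    print(x.device, y.device)
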
#########################################
# You can also set it globally like this:

torch.set_default_device('cuda')

mod = torch.nn.Linear(20, 30)
print(mod.weight.device)
print(mod(torch.randn(128, 20)).device)

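#########################################
# The global setting stays in effect for the rest of the process; to go
# back to the usual behavior, set the default device again. A minimal
# sketch:

torch.set_default_device('cpu')
print(torch.randn(3).device)  # back on cpu
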
################################################################
# This function imposes a slight performance cost on every Python
# call to the torch API (not just factory functions). If this
# is causing problems for you, please comment on
# `this issue <https://github.com/pytorch/pytorch/issues/92701>`__
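
################################################################
# If you want to gauge the overhead on your own machine, a minimal
# sketch using the standard-library ``timeit`` module (run it in a
# fresh Python process so the baseline is measured before any default
# device is set; exact numbers will vary by machine):
#
# ::
#
#    import timeit
#    import torch
#
#    baseline = timeit.timeit(lambda: torch.empty(1), number=100_000)
#    torch.set_default_device('cpu')
#    with_mode = timeit.timeit(lambda: torch.empty(1), number=100_000)
#    print(f"baseline: {baseline:.3f}s  with default device: {with_mode:.3f}s")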
