Commit 77cf6b8

Add changing default device recipe
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
1 parent d970640 commit 77cf6b8

File tree: 1 file changed, 43 additions and 0 deletions
@@ -0,0 +1,43 @@
"""
Changing default device
=======================

It is common to want to write PyTorch code in a device agnostic way,
and then switch between CPU and CUDA depending on what hardware is available.
Traditionally, you might have done this with if-statements and cuda() calls:
"""

import torch

USE_CUDA = False

mod = torch.nn.Linear(20, 30)
if USE_CUDA:
    mod.cuda()

device = 'cpu'
if USE_CUDA:
    device = 'cuda'
inp = torch.randn(128, 20, device=device)
print(mod(inp).device)
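# Rather than hard-coding the flag, the same pattern is usually driven by
# torch.cuda.is_available(); a minimal sketch (not part of the original
# recipe, but using only standard PyTorch APIs):

# ```python
# import torch
#
# # Detect hardware at runtime instead of hard-coding USE_CUDA.
# USE_CUDA = torch.cuda.is_available()
# device = 'cuda' if USE_CUDA else 'cpu'
#
# mod = torch.nn.Linear(20, 30).to(device)   # .to() moves the parameters
# inp = torch.randn(128, 20, device=device)
# print(mod(inp).device)
# ```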

# PyTorch now also has a context manager which can take care of the
# device transfer automatically.

with torch.device('cuda'):
    mod = torch.nn.Linear(20, 30)
    print(mod.weight.device)
    print(mod(torch.randn(128, 20)).device)
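# A sketch of the scoping behavior, using the always-available 'meta'
# device so it runs without a GPU (this example is an addition, not part
# of the original recipe): tensors created inside the block follow the
# context, an explicit device= argument overrides it, and the previous
# default is restored on exit.

# ```python
# import torch
#
# with torch.device('meta'):
#     a = torch.randn(128, 20)                # follows the context: meta
#     b = torch.randn(128, 20, device='cpu')  # explicit argument wins: cpu
# c = torch.randn(128, 20)                    # context exited: back to cpu
# print(a.device, b.device, c.device)
# ```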

# You can also set it globally

torch.set_default_device('cuda')

mod = torch.nn.Linear(20, 30)
print(mod.weight.device)
print(mod(torch.randn(128, 20)).device)
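# The global default can be switched again at any time; a minimal sketch
# (an addition to the recipe, set to 'cpu' so it runs anywhere). The
# torch.get_default_device() call is assumed to be available, which is
# only true on recent PyTorch releases.

# ```python
# import torch
#
# torch.set_default_device('cpu')  # later factory calls allocate on CPU
# a = torch.randn(3)
# print(a.device)
# print(torch.get_default_device())  # reports the current default
# ```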

# This function imposes a slight performance cost on every Python
# call to the torch API (not just factory functions). If this
# is causing problems for you, please comment on
# https://github.com/pytorch/pytorch/issues/92701
