Skip to content

Commit da6a9e0

Browse files
Add shapes tutorial
1 parent a1c7139 commit da6a9e0

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed
Loading
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
Reasoning about Shapes in PyTorch
3+
=================================
4+
5+
When writing models with PyTorch, it is commonly the case that the parameters
6+
to a given layer depend on the shape of the output of the previous layer. For
7+
example, the ``in_features`` of an ``nn.Linear`` layer must match the
8+
``size(-1)`` of the input. For some layers, the shape computation involves
9+
complex equations, for example convolution operations.
10+
11+
One way around this is to run the forward pass with random inputs, but this is
12+
wasteful in terms of compute.
13+
14+
Instead, we can make use of the ``meta`` device to determine the output shapes
15+
of a layer without materializing any data.
16+
"""
17+
18+
import torch
19+
import timeit
20+
21+
# Create the input and the layer on the ``meta`` device: only metadata such
# as shape and dtype is tracked, no tensor data is allocated.
t = torch.rand(2, 3, 10, 10, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")

# Time the forward pass itself -- that is the shape computation this recipe
# is about -- rather than the creation of the input tensor.
start = timeit.default_timer()
out = conv(t)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")
28+
29+
30+
##########################################################################
31+
# Observe that since data is not materialized, passing arbitrarily large
32+
# inputs will not significantly alter the time taken for shape computation.
33+
34+
# The input is far too large to materialize, but on the ``meta`` device only
# its metadata exists, so creating it is cheap.
t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")

# Time the shape computation (the forward pass), not the tensor creation, so
# the number printed below can be compared with the small-input run.
start = timeit.default_timer()
out_large = conv(t_large)
end = timeit.default_timer()

print(out_large)
print(f"Time taken: {end-start}")
40+
41+
######################################################
42+
# Consider an arbitrary network such as the following:
43+
44+
import torch.nn as nn
45+
import torch.nn.functional as F
46+
47+
48+
class Net(nn.Module):
    """Small convolutional classifier used to demonstrate shape inspection.

    Two convolution + max-pool stages are followed by three fully connected
    layers. The ``16 * 5 * 5`` input size of ``fc1`` matches the flattened
    output of the second pooling stage for ``3 x 32 x 32`` inputs
    (32 -> 28 -> 14 -> 10 -> 5 along each spatial dimension).
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        # A single pooling module is reused after both convolutions.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch ``x`` and return the ``fc3`` output.

        The last dimension of the result has size 10.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
66+
67+
###############################################################################
68+
# We can view the intermediate shapes within an entire network by registering a
69+
# forward hook to each layer that prints the shape of the output.
70+
71+
def fw_hook(module, input, output):
    """Forward hook that prints the shape of ``module``'s output.

    The signature is the one expected by ``register_forward_hook``.

    Args:
        module: The module whose forward pass just ran.
        input: The inputs that were passed to ``module`` (unused).
        output: The value returned by ``module``; its ``shape`` is printed.

    Returns:
        None.
    """
    print(f"Shape of output to {module} is {output.shape}.")
73+
74+
75+
# Any tensor created within this torch.device context manager will be
# on the meta device.
with torch.device("meta"):
    net = Net()
    inp = torch.randn((1024, 3, 32, 32))

# ``modules()`` yields the network itself followed by every submodule, so the
# hook also reports the output shape of the network as a whole.
for layer in net.modules():
    layer.register_forward_hook(fw_hook)

# Running the forward pass triggers the hooks; no real data is computed.
out = net(inp)

recipes_source/recipes_index.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
123123
:link: ../recipes/profile_with_itt.html
124124
:tags: Basics
125125

126+
.. customcarditem::
127+
:header: Reasoning about Shapes in PyTorch
128+
:card_description: Learn how to use the meta device to reason about shapes in your model.
129+
:image: ../_static/img/thumbnails/cropped/reasoning-about-shapes.PNG
130+
:link: ../recipes/recipes/reasoning_about_shapes.html
131+
:tags: Basics
132+
126133
.. Interpretability
127134
128135
.. customcarditem::

0 commit comments

Comments
 (0)