
Commit 66213cb

Author: Svetlana Karslioglu
Merge branch 'main' into svekars-patch-3
2 parents: d572452 + c941e50

File tree

3 files changed: +96 -1 lines changed

intermediate_source/reinforcement_ppo.py

Lines changed: 1 addition & 1 deletion

@@ -604,7 +604,7 @@
     data_view = tensordict_data.reshape(-1)
     replay_buffer.extend(data_view.cpu())
     for _ in range(frames_per_batch // sub_batch_size):
-        subdata, *_ = replay_buffer.sample(sub_batch_size)
+        subdata = replay_buffer.sample(sub_batch_size)
         loss_vals = loss_module(subdata.to(device))
         loss_value = (
             loss_vals["loss_objective"]
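A minimal sketch of the updated sampling pattern, assuming a torchrl ``ReplayBuffer`` (the toy buffer below is illustrative, not taken from this diff): ``sample`` returns the sampled batch itself, and only returns a tuple when called with ``return_info=True``, so the tuple unpacking is dropped.

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

# Illustrative toy buffer; the tutorial's real buffer holds TensorDicts
# collected from the environment.
rb = ReplayBuffer(storage=LazyTensorStorage(max_size=100))
rb.extend(torch.arange(100).unsqueeze(-1))

subdata = rb.sample(16)  # the sampled batch itself, not a (data, info) tuple
print(subdata.shape)     # torch.Size([16, 1])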
recipes_source/recipes/reasoning_about_shapes.py

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
"""
Reasoning about Shapes in PyTorch
=================================

When writing models with PyTorch, the parameters of a given layer often
depend on the shape of the output of the previous layer. For example, the
``in_features`` of an ``nn.Linear`` layer must match the ``size(-1)`` of the
input. For some layers, the shape computation involves complex equations, as
with convolution operations, whose output size depends on kernel size,
stride, padding, and dilation.

One way around this is to run the forward pass with random inputs, but this
is wasteful in terms of memory and compute.

Instead, we can make use of the ``meta`` device to determine the output
shapes of a layer without materializing any data.
"""

import torch
import timeit

t = torch.rand(2, 3, 10, 10, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")
start = timeit.default_timer()
out = conv(t)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end - start}")

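##########################################################################
# As a quick sanity check, the shape printed above can be verified against
# the output size formula from the ``Conv2d`` documentation:
# ``out = floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)``.
# ``conv2d_out_size`` below is an illustrative helper, a minimal sketch of
# that formula.


def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# conv above has kernel_size=2, stride=1, padding=0, so a 10x10 input gives 9x9.
assert conv2d_out_size(10, kernel_size=2) == out.shape[-1]
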
##########################################################################
# Observe that since data is not materialized, passing arbitrarily large
# inputs will not significantly alter the time taken for shape computation.

t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")
start = timeit.default_timer()
out = conv(t_large)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end - start}")

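##########################################################################
# The reason this stays cheap: a meta tensor carries only metadata (shape,
# dtype, strides) and no storage. As a minimal illustration, the metadata is
# available as usual, but recent PyTorch versions raise an error when asked
# to copy data out of a meta tensor, since there is none:

print(t_large.shape, t_large.dtype)
try:
    t_large.cpu()
except Exception as e:
    print(f"As expected: {e}")
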
######################################################
# Consider an arbitrary network such as the following:

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

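###############################################################################
# A quick check of where ``fc1``'s ``16 * 5 * 5`` comes from, assuming 32x32
# inputs: conv1 (kernel 5) maps 32 to 28, pooling halves it to 14, conv2
# (kernel 5) maps 14 to 10, and pooling halves it to 5, with 16 output
# channels. A minimal sketch verifying this on the meta device:

with torch.device("meta"):
    check_net = Net()
    check_inp = torch.randn(1, 3, 32, 32)
x = check_net.pool(F.relu(check_net.conv1(check_inp)))
x = check_net.pool(F.relu(check_net.conv2(x)))
print(x.shape)  # torch.Size([1, 16, 5, 5])
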
###############################################################################
# We can view the intermediate shapes within an entire network by registering
# a forward hook on each layer that prints the shape of its output.


def fw_hook(module, input, output):
    print(f"Shape of output of {module} is {output.shape}.")


# Any tensor created within this torch.device context manager will be
# on the meta device.
with torch.device("meta"):
    net = Net()
    inp = torch.randn((1024, 3, 32, 32))

for name, layer in net.named_modules():
    layer.register_forward_hook(fw_hook)

out = net(inp)
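
###############################################################################
# For the (1024, 3, 32, 32) input above, the hooks should report, in execution
# order: conv1 -> (1024, 6, 28, 28), pool -> (1024, 6, 14, 14),
# conv2 -> (1024, 16, 10, 10), pool -> (1024, 16, 5, 5), fc1 -> (1024, 120),
# fc2 -> (1024, 84), fc3 -> (1024, 10), and finally ``Net`` itself ->
# (1024, 10). The shared ``pool`` module fires its hook once per call.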

recipes_source/recipes_index.rst

Lines changed: 7 additions & 0 deletions
@@ -123,6 +123,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/profile_with_itt.html
    :tags: Basics

+.. customcarditem::
+   :header: Reasoning about Shapes in PyTorch
+   :card_description: Learn how to use the meta device to reason about shapes in your model.
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/recipes/reasoning_about_shapes.html
+   :tags: Basics
+
 .. Interpretability

 .. customcarditem::
