"""
Reasoning about Shapes in PyTorch
=================================

When writing models with PyTorch, it is commonly the case that the parameters
to a given layer depend on the shape of the output of the previous layer. For
example, the ``in_features`` of an ``nn.Linear`` layer must match the
``size(-1)`` of the input. For some layers, such as convolutions, the shape
computation involves more complex equations.

One way around this is to run the forward pass with random inputs, but this is
wasteful in terms of memory and compute.

Instead, we can make use of the ``meta`` device to determine the output shapes
of a layer without materializing any data.
"""

import torch
import timeit

t = torch.rand(2, 3, 10, 10, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")
start = timeit.default_timer()
out = conv(t)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")
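
##########################################################################
# As a quick sanity check (this snippet is an illustrative addition, not part
# of the recipe above), the spatial output size of ``nn.Conv2d`` follows the
# formula from the PyTorch docs:
# ``H_out = floor((H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)``.
# For the layer above (kernel_size=2, stride=1, padding=0, dilation=1) and a
# 10x10 input, that gives 9x9, which matches the shape reported on ``meta``.

expected = (10 + 2 * 0 - 1 * (2 - 1) - 1) // 1 + 1
assert out.shape == (2, 5, expected, expected)
print(f"Expected spatial size: {expected}, got {tuple(out.shape[-2:])}")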


##########################################################################
# Observe that since data is not materialized, passing arbitrarily large
# inputs will not significantly alter the time taken for shape computation.

t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")
start = timeit.default_timer()
out = conv(t_large)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")
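
##########################################################################
# To get a sense of why this matters, here is a back-of-the-envelope
# calculation (added for illustration): materializing ``t_large`` as float32
# would need roughly 48 TiB of memory, so actually allocating it is not an
# option.

numel = (2**10) * 3 * (2**16) * (2**16)
print(f"Materializing t_large would need ~{numel * 4 / 2**40:.0f} TiB")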


######################################################
# Consider an arbitrary network such as the following:

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
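
###############################################################################
# The ``16 * 5 * 5`` passed to ``fc1`` is exactly the kind of value that
# usually has to be worked out by hand. As a small illustrative sketch (not
# part of the original recipe), we can instead derive it by running a meta
# input through just the convolutional part of the network, assuming the
# 32x32 spatial size used later in this recipe:

with torch.device("meta"):
    feature_extractor = nn.Sequential(
        nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
        nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
    )
    features = feature_extractor(torch.randn(1, 3, 32, 32))

# size(-1) of the flattened features is the in_features that fc1 must use.
print(f"fc1 needs in_features={features.flatten(1).size(-1)}")  # 16 * 5 * 5 = 400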

###############################################################################
# We can view the intermediate shapes within an entire network by registering
# a forward hook on each layer that prints the shape of the output.

def fw_hook(module, input, output):
    print(f"Shape of output of {module} is {output.shape}.")


# Any tensor created within this torch.device context manager will be
# on the meta device.
with torch.device("meta"):
    net = Net()
    inp = torch.randn((1024, 3, 32, 32))

for name, layer in net.named_modules():
    layer.register_forward_hook(fw_hook)

out = net(inp)
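
###############################################################################
# Note that ``net`` itself still lives on the meta device, so its parameters
# hold no real data. One possible way to materialize it afterwards (an
# illustrative sketch, not part of the original recipe) is ``to_empty()``,
# which allocates real but uninitialized storage, followed by re-running each
# layer's default initialization:

net = net.to_empty(device="cpu")
for module in net.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()  # re-run the layer's default init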