
Commit 6e261b2

Jessica Lin, ezyang, guyang3532, and hritikbhandari authored

1.6 model freezing tutorial (#1077)

* Update feature classification labels
* Update NVidia -> Nvidia
* Bring back default filename_pattern so that by default we run all galleries.
  Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Add prototype_source directory
* Add prototype directory
* Add prototype
* Remove extra "done"
* Add README.txt
* Update for prototype instructions
* Update for prototype feature
* Refine torchvision_tutorial doc for Windows
* Update neural_style_tutorial.py (#1059)
  Updated the mistake in the Loading Images section.
* torch_script_custom_ops restructure (#1057)
  Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Port custom ops tutorial to new registration API, increase testability.
  Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Kill some other occurrences of RegisterOperators
  Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Update README.md
* Make torch_script_custom_classes tutorial runnable.
  Also fixed some warnings in the tutorial, fixed some minor bitrot
  (e.g., torch::script::Module to torch::jit::Module), and added some missing
  quotes around some bash expansions.
  Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Update torch_script_custom_classes to use TORCH_LIBRARY (#1062)
  Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Add Model Freezing in TorchScript

Co-authored-by: Edward Z. Yang <ezyang@fb.com>
Co-authored-by: Yang Gu <yangu@microsoft.com>
Co-authored-by: Hritik Bhandari <bhandari.hritik@gmail.com>

1 parent f8b200d · commit 6e261b2

File tree

2 files changed: +135 −1 lines changed


prototype_source/README.txt

Lines changed: 1 addition & 1 deletion

@@ -2,4 +2,4 @@ Prototype Tutorials
 ------------------
 1. distributed_rpc_profiling.rst
    Profiling PyTorch RPC-Based Workloads
-   https://github.com/pytorch/tutorials/blob/release/1.6/prototype_source/distributed_rpc_profiling.rst
+   https://github.com/pytorch/tutorials/blob/release/1.6/prototype_source/distributed_rpc_profiling.rst
Lines changed: 134 additions & 0 deletions

@@ -0,0 +1,134 @@
"""
Model Freezing in TorchScript
=============================

In this tutorial, we introduce the syntax for *model freezing* in TorchScript.
Freezing is the process of inlining PyTorch module parameters and attribute
values into the TorchScript internal representation. Parameter and attribute
values are treated as final values, and they cannot be modified in the
resulting frozen module.

Basic Syntax
------------
Model freezing can be invoked using the API below:

``torch.jit.freeze(mod : ScriptModule, names : str[]) -> ScriptModule``

Note that the input module can be the result of either scripting or tracing.
See https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

Next, we demonstrate how freezing works using an example:
"""
import time

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout2d(0.25)
        self.dropout2 = torch.nn.Dropout2d(0.5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = torch.nn.functional.log_softmax(x, dim=1)
        return output

    @torch.jit.export
    def version(self):
        return 1.0

net = torch.jit.script(Net())
fnet = torch.jit.freeze(net)

print(net.conv1.weight.size())
print(net.conv1.bias)

try:
    print(fnet.conv1.bias)
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'conv1'
except RuntimeError:
    print("field 'conv1' is inlined. It does not exist in 'fnet'")

try:
    fnet.version()
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'version'
except RuntimeError:
    print("method 'version' is deleted in 'fnet'. Only 'forward' is preserved")

fnet2 = torch.jit.freeze(net, ["version"])

print(fnet2.version())
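###############################################################
# A quick sketch (an addition, not in the original tutorial): inspecting the
# graph of a frozen module shows that parameter and attribute values have
# been inlined as constants:

print(fnet2.graph)  # conv/linear weights now appear as constants in the graph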
B = 1
warmup = 1
num_iters = 1000
input = torch.rand(B, 1, 28, 28)

start = time.time()
for i in range(warmup):
    net(input)
end = time.time()
print("Scripted - Warm up time: {0:7.4f}".format(end - start), flush=True)

start = time.time()
for i in range(warmup):
    fnet(input)
end = time.time()
print("Frozen - Warm up time: {0:7.4f}".format(end - start), flush=True)

start = time.time()
for i in range(num_iters):
    input = torch.rand(B, 1, 28, 28)
    net(input)
end = time.time()
print("Scripted - Inference: {0:5.2f}".format(end - start), flush=True)

start = time.time()
for i in range(num_iters):
    input = torch.rand(B, 1, 28, 28)
    fnet2(input)
end = time.time()
print("Frozen - Inference time: {0:5.2f}".format(end - start), flush=True)
###############################################################
# On my machine, I measured the time:
#
# * Scripted - Warm up time: 0.0107
# * Frozen - Warm up time: 0.0048
# * Scripted - Inference: 1.35
# * Frozen - Inference time: 1.17

###############################################################
# In our example, warm up time measures the first run of each model. Warming
# up the frozen model took roughly half the time of the scripted model
# (0.0048s vs. 0.0107s). On some more complex models, we observed even higher
# warm up speed ups. Freezing achieves this speed up because it does ahead of
# time some of the work that TorchScript would otherwise do during the first
# couple of runs.
#
# Inference time measures inference execution time after the model is warmed
# up. Although we observed significant variation in execution time, the
# frozen model is often about 15% faster than the scripted model. When the
# input is larger, we observe a smaller speed up because the execution is
# dominated by tensor operations.
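###############################################################
# A hedged aside (not part of the original measurements): inference
# benchmarks are often run under ``torch.no_grad()`` so that autograd
# bookkeeping does not contribute to the timing:

with torch.no_grad():
    start = time.time()
    for i in range(num_iters):
        fnet2(torch.rand(B, 1, 28, 28))
    end = time.time()
print("Frozen - no_grad inference: {0:5.2f}".format(end - start), flush=True)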
###############################################################
# Conclusion
# ----------
# In this tutorial, we learned about model freezing. Freezing is a useful
# technique to optimize models for inference, and it can also significantly
# reduce TorchScript warm up time.
