Commit c0a92ae

Add Model Freezing in TorchScript
1 parent d40dc05 commit c0a92ae

2 files changed: +142 -0 lines changed
recipes_source/torchscript_freezing.py

Lines changed: 135 additions & 0 deletions

@@ -0,0 +1,135 @@
"""
Model Freezing in TorchScript
=============================

In this tutorial, we introduce the syntax for *model freezing* in TorchScript.
Freezing is the process of inlining PyTorch module parameter and attribute
values into the TorchScript internal representation. Parameter and attribute
values are treated as final, and they cannot be modified in the resulting
frozen module.

Basic Syntax
------------

Model freezing can be invoked using the API below:

* ``torch.jit.freeze(mod : ScriptModule, names : str[]) -> ScriptModule``

Note that the input module can be the result of either scripting or tracing.
See https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
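
For example, a traced module can be frozen the same way. The following is a
minimal sketch (the toy ``AddOne`` module and the example input are ours, not
part of this recipe; depending on the PyTorch version, ``torch.jit.freeze``
may require the module to be in eval mode)::

    import torch

    class AddOne(torch.nn.Module):
        def forward(self, x):
            return x + 1

    # trace with an example input, then freeze the traced module
    traced = torch.jit.trace(AddOne().eval(), torch.rand(2, 2))
    frozen = torch.jit.freeze(traced)
    print(frozen(torch.ones(2, 2)))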

Next, we demonstrate how freezing works using an example:
"""

import time

import torch


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout2d(0.25)
        self.dropout2 = torch.nn.Dropout2d(0.5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = torch.nn.functional.log_softmax(x, dim=1)
        return output

    @torch.jit.export
    def version(self):
        return 1.0

# Freezing expects a module in eval mode; dropout is disabled for inference.
net = torch.jit.script(Net().eval())
fnet = torch.jit.freeze(net)

print(net.conv1.weight.size())
print(net.conv1.bias)

try:
    print(fnet.conv1.bias)
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'conv1'
except RuntimeError:
    print("field 'conv1' is inlined. It does not exist in 'fnet'")

try:
    fnet.version()
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'version'
except RuntimeError:
    print("method 'version' is deleted in 'fnet'. Only 'forward' is preserved")

# attributes and methods can be kept in the frozen module by naming them
fnet2 = torch.jit.freeze(net, ["version"])

print(fnet2.version())
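
# Our addition, not part of the original recipe: printing the frozen graph
# is a quick way to see the effect of freezing. Parameter and attribute
# values appear inlined as constants; the exact printout varies across
# PyTorch versions.
print(fnet.graph)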

B = 1
warmup = 1
iters = 1000
input = torch.rand(B, 1, 28, 28)

start = time.time()
for i in range(warmup):
    net(input)
end = time.time()
print("Scripted - Warm up time: {0:7.4f}".format(end - start), flush=True)

start = time.time()
for i in range(warmup):
    fnet(input)
end = time.time()
print("Frozen - Warm up time: {0:7.4f}".format(end - start), flush=True)

start = time.time()
for i in range(iters):
    input = torch.rand(B, 1, 28, 28)
    net(input)
end = time.time()
print("Scripted - Inference: {0:5.2f}".format(end - start), flush=True)

start = time.time()
for i in range(iters):
    input = torch.rand(B, 1, 28, 28)
    fnet2(input)
end = time.time()
print("Frozen - Inference time: {0:5.2f}".format(end - start), flush=True)

"""
112+
On my machine, I measured the time:
113+
114+
Scripted - Warm up time: 0.0107
115+
Frozen - Warm up time: 0.0048
116+
Scripted - Inference: 1.35
117+
Frozen - Inference time: 1.17
118+
119+
In our example, warm up time measures the first two runs. The frozen model
120+
is 50% faster than the scripted model. On some more complex models, we
121+
observed even higher speed up of warm up time. freezing achieves this speed up
122+
because it is doing some the work TorchScript has to do when the first couple
123+
runs are initiated.
124+
125+
Inference time measures inference execution time after the model is warmed up.
126+
Although we observed significant variation in execution time, the
127+
frozen model is often about 15% faster than the scripted model. When input is larger,
128+
we observe a smaller speed up because the execution is dominated by tensor operations.
129+
130+
Conclusion
131+
-----------
132+
133+
In this tutorial, we learned about model freezing. Freezing is a useful technique to
134+
optimize models for inference and it also can significantly reduce TorchScript warmup time.
135+
"""

recipes_source/recipes_index.rst

Lines changed: 7 additions & 0 deletions
@@ -146,6 +146,12 @@ Recipes are bite-sized, actionable examples of how to use specific Py
    :link: ../recipes/deployment_with_flask.html
    :tags: Production,TorchScript

+.. customcarditem::
+   :header: Model Freezing in TorchScript
+   :card_description: Learn how to use the freezing API to optimize your trained model in TorchScript and run inference.
+   :image: ../_static/img/thumbnails/cropped/torchscript_overview.png
+   :link: ../recipes/torchscript_freezing.html
+   :tags: TorchScript

 .. End of tutorial card section

@@ -180,3 +186,4 @@ Recipes are bite-sized, actionable examples of how to use specific Py
 /recipes/recipes/dynamic_quantization
 /recipes/torchscript_inference
 /recipes/deployment_with_flask
+/recipes/torchscript_freezing
