============================================================================

**Author:** `Animesh Jain <https://github.com/anijain2305>`_

As deep learning models get larger, the compilation time of these models also
increases. This extended compilation time can result in a large startup time in
inference services or wasted resources in large-scale training. This recipe
shows how to reduce the cold start compilation time by compiling a repeated
region of the model instead of the entire model.

Before starting this recipe, make sure that ``torch`` is installed:

.. code-block:: sh

   pip install torch

.. note::

   This feature is available starting with the 2.5 release. If you are using version 2.4,
   you can enable the configuration flag ``torch._dynamo.config.inline_inbuilt_nn_modules=True``
   to prevent recompilations during regional compilation. In version 2.5, this flag is enabled by default.
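
   For example (a minimal sketch; this is only needed on version 2.4, since version
   2.5 and later enable the flag by default):

   .. code-block:: python

      import torch._dynamo

      torch._dynamo.config.inline_inbuilt_nn_modules = True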
"""

from time import perf_counter

######################################################################
# Steps
# -----
#
# In this recipe, we will follow these steps:
#
# 1. Import all necessary libraries.
# 2. Define and initialize a neural network with repeated regions.
# 3. Understand the difference between the full model and the regional compilation.
# 4. Measure the compilation time of the full model and the regional compilation.
#
# First, let's import the necessary libraries for loading our data:
#

import torch
import torch.nn as nn

##########################################################
# Next, let's define and initialize a neural network with repeated regions.
#
# Typically, neural networks are composed of repeated layers. For example, a
# large language model is composed of many Transformer blocks. In this recipe,
# we will create a ``Layer`` using the ``nn.Module`` class as a proxy for a repeated region.
# We will then create a ``Model`` which is composed of 64 instances of this
# ``Layer`` class.
#

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self, apply_regional_compilation):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        # Apply compile only to the repeated layers.
        if apply_regional_compilation:
            self.layers = torch.nn.ModuleList(
                [torch.compile(Layer()) for _ in range(64)]
            )
        else:
            self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, the self.linear is outside of the scope of ``torch.compile``.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x


####################################################
# Next, let's review the difference between the full model and the regional compilation.
#
# In full model compilation, the entire model is compiled as a whole. This is the common approach
# most users take with ``torch.compile``. In this example, we apply ``torch.compile`` to
# the ``Model`` object. This will effectively inline the 64 layers, producing a
# large graph to compile. You can look at the full graph by running this recipe
# with ``TORCH_LOGS=graph_code``.
#
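#
# As an optional, illustrative sketch (not part of the original recipe), the same
# graph logs can also be enabled from Python with the ``torch._logging`` API
# instead of the ``TORCH_LOGS`` environment variable:
#
# .. code-block:: python
#
#    import torch._logging
#
#    # Equivalent to running with TORCH_LOGS=graph_code: print the generated
#    # code for each graph that ``torch.compile`` captures.
#    torch._logging.set_logs(graph_code=True)
#
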
model = Model(apply_regional_compilation=False).cuda()
full_compiled_model = torch.compile(model)

#####################################################
# Regional compilation, on the other hand, compiles a region of the model.
# By strategically choosing to compile a repeated region of the model, we can compile a
# much smaller graph and then reuse the compiled graph for all the regions.
# In the example, ``torch.compile`` is applied only to the ``layers`` and not the full model.
#

regional_compiled_model = Model(apply_regional_compilation=True).cuda()
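
######################################################
# As a side note (an illustrative sketch, not part of the original recipe), the
# same idea applies to an existing model that already stores its repeated blocks
# in an ``nn.ModuleList``: the blocks can be swapped in place for their compiled
# wrappers while the rest of the model stays eager.
#
# .. code-block:: python
#
#    # ``my_model.layers`` is a hypothetical ``nn.ModuleList`` of identical blocks.
#    for idx, block in enumerate(my_model.layers):
#        my_model.layers[idx] = torch.compile(block)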

#####################################################
# Applying compilation to a repeated region, instead of the full model, leads to
# large savings in compile time. Here, we will just compile a layer instance and
# then reuse it 64 times in the ``Model`` object.
#
# Note that with repeated regions, some part of the model might not be compiled.
# For example, the ``self.linear`` in the ``Model`` is outside of the scope of
# regional compilation.
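# As an illustrative check (not part of the original recipe), you can confirm
# which submodules were wrapped by ``torch.compile``:
#
# .. code-block:: python
#
#    # The repeated blocks are OptimizedModule wrappers produced by torch.compile,
#    # while ``self.linear`` stays a plain eager module.
#    print(type(regional_compiled_model.layers[0]))  # torch._dynamo.eval_frame.OptimizedModule
#    print(type(regional_compiled_model.linear))     # torch.nn.modules.linear.Linear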
#
# Also, note that there is a tradeoff between performance speedup and compile
# time. Full model compilation involves a larger graph and,
# theoretically, offers more scope for optimizations. However, for practical
# purposes and depending on the model, we have observed many cases with minimal
# speedup differences between the full model and regional compilation.

###################################################
# Next, let's measure the compilation time of the full model and the regional compilation.
#
# ``torch.compile`` is a JIT compiler, which means that it compiles on the first invocation.
# In the code below, we measure the total time spent in the first invocation. While this method is not
# precise, it provides a good estimate since the majority of the time is spent in
# compilation.


def measure_latency(fn, input):
    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        start = perf_counter()
        fn(input)
        torch.cuda.synchronize()
        end = perf_counter()
        return end - start


input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_latency(full_compiled_model, input)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_latency(regional_compiled_model, input)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency

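######################################################
# As an optional follow-up (an illustrative sketch, not part of the original
# recipe), you can also compare the steady-state latency of the two compiled
# models after a warm-up, to see how the speedup tradeoff discussed above plays
# out for this toy model:
#
# .. code-block:: python
#
#    def measure_steady_state(fn, input, iters=100):
#        # Warm up so that compilation time is excluded from the measurement.
#        for _ in range(3):
#            fn(input)
#        torch.cuda.synchronize()
#        start = perf_counter()
#        for _ in range(iters):
#            fn(input)
#        torch.cuda.synchronize()
#        return (perf_counter() - start) / iters
#
#    print(f"Full model latency = {measure_steady_state(full_compiled_model, input):.6f} seconds")
#    print(f"Regional latency = {measure_steady_state(regional_compiled_model, input):.6f} seconds")
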
############################################################################
# Conclusion
# -----------
#
# This recipe shows how to reduce the cold start compilation time if your model
# has repeated regions. This approach requires user modifications to apply ``torch.compile`` to
# the repeated regions instead of the more commonly used full model compilation. We
# are continually working on reducing cold start compilation time.
#