@@ -175,7 +175,7 @@ def evaluate(mod, inp):
 ######################################################################
 # And indeed, we can see that running our model with ``torch.compile``
 # results in a significant speedup. On an NVIDIA A100 GPU, we observe a
-# 2.2x speedup. Speedup mainly comes from reducing Python overhead and
+# 2.3x speedup. Speedup mainly comes from reducing Python overhead and
 # GPU read/writes, and so the observed speedup may vary on factors such as model
 # architecture and batch size. For example, if a model's architecture is simple
 # and the amount of data is large, then the bottleneck would be
@@ -197,16 +197,16 @@ def evaluate(mod, inp):
 opt = torch.optim.Adam(model.parameters())

 def train(mod, data):
+    opt.zero_grad(True)
     pred = mod(data[0])
     loss = torch.nn.CrossEntropyLoss()(pred, data[1])
     loss.backward()
+    opt.step()

 eager_times = []
 for i in range(N_ITERS):
     inp = generate_data(16)
-    opt.zero_grad(True)
     _, eager_time = timed(lambda: train(model, inp))
-    opt.step()
     eager_times.append(eager_time)
     print(f"eager train time {i}: {eager_time}")
 print("~" * 10)
@@ -218,9 +218,7 @@ def train(mod, data):
 compile_times = []
 for i in range(N_ITERS):
     inp = generate_data(16)
-    opt.zero_grad(True)
     _, compile_time = timed(lambda: train_opt(model, inp))
-    opt.step()
     compile_times.append(compile_time)
     print(f"compile train time {i}: {compile_time}")
 print("~" * 10)
@@ -235,13 +233,7 @@ def train(mod, data):
 # Again, we can see that ``torch.compile`` takes longer in the first
 # iteration, as it must compile the model, but afterward, we see
 # significant speedups compared to eager. On an NVIDIA A100 GPU, we
-# observe a 1.8x speedup.
-#
-# One thing to note is that, as of now, we cannot place optimizer code --
-# ``opt.zero_grad`` and ``opt.step`` -- inside of an optimized function.
-# The rest of the training loop -- the forward pass and the backward pass --
-# can be optimized. We are currently working on enabling optimizers to be
-# compatible with ``torch.compile``.
+# observe a 2.2x speedup.

 ######################################################################
 # Comparison to TorchScript and FX Tracing