# ``torch.compile`` makes PyTorch code run faster by
# JIT-compiling PyTorch code into optimized kernels,
# all while requiring minimal code changes.
#
# In this tutorial, we cover basic ``torch.compile`` usage,
# and demonstrate the advantages of ``torch.compile`` over
# previous PyTorch compiler solutions, such as
# `TorchScript <https://pytorch.org/docs/stable/jit.html>`__ and
# `FX Tracing <https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace>`__.
#
# **Contents**
#
# - Basic Usage
# - Demonstrating Speedups
# - Comparison to TorchScript and FX Tracing
#
# ``torch.compile`` is included in the latest PyTorch.
# Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly
# binary. If Triton is still missing, try installing ``torchtriton`` via pip
# (``pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"``
# for CUDA 11.7).
#
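
######################################################################
# As a minimal sketch of basic usage (the names ``fn``, ``opt_fn``, and
# ``opt_mod`` below are illustrative, not taken from the full Basic Usage
# section), ``torch.compile`` can wrap an arbitrary Python function or
# ``torch.nn.Module`` and be used as a drop-in replacement:

import torch

def fn(x, y):
    # Plain PyTorch code; no changes are required to compile it.
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

opt_fn = torch.compile(fn)  # returns an optimized callable
print(opt_fn(torch.randn(10), torch.randn(10)))

# ``torch.compile`` also accepts ``nn.Module`` instances.
mod = torch.nn.Linear(10, 10)
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(4, 10)))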

######################################################################
# Demonstrating Speedups
# -----------------------
#
# Let's now demonstrate that using ``torch.compile`` can speed
# up real models. We will compare standard eager mode and
# ``torch.compile`` by evaluating and training a ``torchvision`` model on random data.
#
# Before we start, we need to define some utility functions.
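# A rough sketch of what these utilities might look like is shown below. The
# bodies are assumptions for illustration: ``generate_data`` and ``init_model``
# are names used later in this tutorial, while ``timed`` is a helper name
# chosen here.

import time
import torch
import torchvision.models as models

device = "cuda" if torch.cuda.is_available() else "cpu"

def timed(fn):
    # Time a callable; CUDA events account for asynchronous GPU kernels.
    if device == "cuda":
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn()
        end.record()
        torch.cuda.synchronize()
        return result, start.elapsed_time(end) / 1000.0  # seconds
    t0 = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - t0

def generate_data(batch_size):
    # Random image-like inputs and integer class labels (shapes are illustrative).
    return (
        torch.randn(batch_size, 3, 128, 128, device=device),
        torch.randint(1000, (batch_size,), device=device),
    )

def init_model():
    # A stand-in ``torchvision`` model; the full tutorial may use a different one.
    return models.resnet18().to(device)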
######################################################################
# Comparison to TorchScript and FX Tracing
# -----------------------------------------
#
# We have seen that ``torch.compile`` can speed up PyTorch code.
# Why else should we use ``torch.compile`` over existing PyTorch
# compiler solutions, such as TorchScript or FX Tracing? Primarily, the
# advantage of ``torch.compile`` lies in its ability to handle
# arbitrary Python code with minimal changes to existing code.
#
# One case that ``torch.compile`` can handle that other compiler
# solutions struggle with is data-dependent control flow (the
# ``if x.sum() < 0:`` line below).

def f1(x, y):
    # The data-dependent branch below is the point of the example; the rest of
    # the body is a plausible completion so that the function runs.
    if x.sum() < 0:
        x = x * -1
    return x * y
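
######################################################################
# As a sketch of the contrast (the inputs and the ``traced_f1`` /
# ``compiled_f1`` names below are illustrative, not the full tutorial's
# comparison), FX tracing cannot trace through the data-dependent branch,
# ``torch.jit.trace`` silently bakes in whichever branch the example inputs
# take, and ``torch.compile`` handles the branch correctly:

import traceback as tb
import torch
import torch.fx

inp1, inp2 = torch.randn(5, 5), torch.randn(5, 5)

# FX tracing fails because ``x.sum() < 0`` cannot be evaluated on a symbolic proxy.
try:
    torch.fx.symbolic_trace(f1)
except:
    tb.print_exc()

# ``torch.jit.trace`` records only the branch taken for the example inputs,
# so the traced function can silently return wrong results for other inputs.
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print(torch.allclose(traced_f1(-inp1, inp2), f1(-inp1, inp2)))  # may print False

# ``torch.compile`` re-traces as needed and handles the branch correctly.
compiled_f1 = torch.compile(f1)
print(torch.allclose(compiled_f1(-inp1, inp2), f1(-inp1, inp2)))  # True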

######################################################################
# TorchDynamo and FX Graphs
# -------------------------
#
# One important component of ``torch.compile`` is TorchDynamo.
# TorchDynamo is responsible for JIT compiling arbitrary Python code into
# `FX graphs <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__, which can
# then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode
# during runtime and detecting calls to PyTorch operations.
#
# Normally, TorchInductor, another component of ``torch.compile``,
# further compiles the FX graphs into optimized kernels,
# but TorchDynamo allows for different backends to be used. In order to inspect
# the FX graphs that TorchDynamo outputs, we can create a custom backend that
# prints the extracted graph and simply returns the graph's unoptimized
# ``forward`` method.
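# A rough sketch of such a backend is shown below; ``inspect_backend`` and
# ``toy_fn`` are illustrative names, not the definitions used in the full
# tutorial. A backend is simply a callable that receives the captured
# ``torch.fx.GraphModule`` together with a list of example inputs and returns
# a callable to run in place of the original code.

from typing import List

import torch
import torch.fx

def inspect_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Print the FX graph that TorchDynamo extracted...
    print("custom backend called with FX graph:")
    print(gm.graph)
    # ...and return the unoptimized forward method, so results are unchanged.
    return gm.forward

def toy_fn(x, y):
    return torch.sin(x) + torch.cos(y)

opt_toy_fn = torch.compile(toy_fn, backend=inspect_backend)
opt_toy_fn(torch.randn(10), torch.randn(10))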

# Reset since we are using a different backend.
torch._dynamo.reset()
explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
print(explain_output)

######################################################################
# In order to maximize speedup, graph breaks should be limited.
# We can check whether the model we used above for demonstrating speedups
# triggers any graph breaks by compiling it with ``fullgraph=True``, which
# raises an error at the first graph break:

opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))

######################################################################
# <!----TODO: replace this section with a link to the torch.export tutorial when done --->
#
# Finally, if we simply want TorchDynamo to output the FX graph for export,
# we can use ``torch._dynamo.export``. Note that ``torch._dynamo.export``, like
# ``fullgraph=True``, raises an error if TorchDynamo breaks the graph.

try:
    torch._dynamo.export(bar)(torch.randn(10), torch.randn(10))
except:
    tb.print_exc()

model_exp = torch._dynamo.export(init_model())(generate_data(16)[0])
print(model_exp[0](generate_data(16)[0]))

######################################################################