
Commit b69a12a

Svetlana Karslioglu and williamwen42 authored
Add Torch export tutorial
- Add Torch export tutorial (#2557)
- Temporarily pull 2.1 binaries in the stable branch
- Update wordlist

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
Co-authored-by: William Wen <williamwen@meta.com>
1 parent a0277fd commit b69a12a

4 files changed, +562 -19 lines changed

.jenkins/build.sh

Lines changed: 4 additions & 5 deletions
@@ -24,11 +24,10 @@ pip install --progress-bar off -r $DIR/../requirements.txt
 
 #Install PyTorch Nightly for test.
 # Nightly - pip install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html
-# RC Link
-# pip uninstall -y torch torchvision torchaudio torchtext
-# pip install --pre --upgrade -f https://download.pytorch.org/whl/test/cu102/torch_test.html torch torchvision torchaudio torchtext
-# pip uninstall -y torch torchvision torchaudio torchtext
-# pip install --pre --upgrade -f https://download.pytorch.org/whl/test/cu116/torch_test.html torch torchdata torchvision torchaudio torchtext
+# Install 2.1 for testing
+pip uninstall -y torch torchvision torchaudio torchtext torchdata
+pip3 install torch torchvision torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/test/cu121
+pip3 install torchdata torchtext --index-url https://download.pytorch.org/whl/test/cpu
 
 # Install two language tokenizers for Translation with TorchText tutorial
 python -m spacy download en_core_web_sm
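A quick way to sanity-check that the RC wheels installed by the commands above were actually picked up is to query the versions from Python; a minimal sketch, where the expected version strings are assumptions based on this commit:

# Sketch: confirm which binaries the CI job picked up after the install step above.
# The expected "2.1.x" / "cu121" values are assumptions, not output from this commit.
import torch
import torchvision

print(torch.__version__)        # expected to look like 2.1.0+cu121 for the test wheels
print(torch.version.cuda)       # CUDA toolkit the wheel was built against, e.g. 12.1
print(torchvision.__version__)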

en-wordlist.txt

Lines changed: 1 addition & 0 deletions
@@ -436,6 +436,7 @@ summarization
 swappable
 tanh
 th
+tensor's
 thresholding
 timestep
 timesteps

intermediate_source/torch_compile_tutorial.py

Lines changed: 14 additions & 14 deletions
@@ -11,15 +11,15 @@
 # ``torch.compile`` makes PyTorch code run faster by
 # JIT-compiling PyTorch code into optimized kernels,
 # all while requiring minimal code changes.
-#
+#
 # In this tutorial, we cover basic ``torch.compile`` usage,
 # and demonstrate the advantages of ``torch.compile`` over
 # previous PyTorch compiler solutions, such as
-# `TorchScript <https://pytorch.org/docs/stable/jit.html>`__ and
+# `TorchScript <https://pytorch.org/docs/stable/jit.html>`__ and
 # `FX Tracing <https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace>`__.
 #
 # **Contents**
-#
+#
 # - Basic Usage
 # - Demonstrating Speedups
 # - Comparison to TorchScript and FX Tracing
@@ -59,7 +59,7 @@
 #
 # ``torch.compile`` is included in the latest PyTorch..
 # Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly
-# binary. If Triton is still missing, try installing ``torchtriton`` via pip
+# binary. If Triton is still missing, try installing ``torchtriton`` via pip
 # (``pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"``
 # for CUDA 11.7).
 #
@@ -104,7 +104,7 @@ def forward(self, x):
 # -----------------------
 #
 # Let's now demonstrate that using ``torch.compile`` can speed
-# up real models. We will compare standard eager mode and
+# up real models. We will compare standard eager mode and
 # ``torch.compile`` by evaluating and training a ``torchvision`` model on random data.
 #
 # Before we start, we need to define some utility functions.
@@ -253,15 +253,15 @@ def train(mod, data):
 ######################################################################
 # Comparison to TorchScript and FX Tracing
 # -----------------------------------------
-#
+#
 # We have seen that ``torch.compile`` can speed up PyTorch code.
 # Why else should we use ``torch.compile`` over existing PyTorch
 # compiler solutions, such as TorchScript or FX Tracing? Primarily, the
 # advantage of ``torch.compile`` lies in its ability to handle
 # arbitrary Python code with minimal changes to existing code.
 #
 # One case that ``torch.compile`` can handle that other compiler
-# solutions struggle with is data-dependent control flow (the
+# solutions struggle with is data-dependent control flow (the
 # ``if x.sum() < 0:`` line below).
 
 def f1(x, y):
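The body of ``f1`` is not shown in this hunk. A minimal sketch of the data-dependent branch the comments above describe, compiled with ``torch.compile``; the exact body is an assumption based on those comments:

# Sketch only: a function whose control flow depends on tensor data,
# as described in the tutorial text above. The exact body is assumed.
import torch

def f1(x, y):
    if x.sum() < 0:   # data-dependent branch that tracing-based tools struggle with
        return -y
    return y

compiled_f1 = torch.compile(f1)
print(compiled_f1(torch.randn(10), torch.randn(10)))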
@@ -399,7 +399,7 @@ def f3(x):
 # `FX graphs <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__, which can
 # then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode
 # during runtime and detecting calls to PyTorch operations.
-#
+#
 # Normally, TorchInductor, another component of ``torch.compile``,
 # further compiles the FX graphs into optimized kernels,
 # but TorchDynamo allows for different backends to be used. In order to inspect
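As a rough illustration of the custom-backend hook mentioned in this hunk, here is a minimal sketch of a backend that just prints the FX graph TorchDynamo captured and then falls back to eager execution; the name ``inspect_backend`` is made up for this example:

# Sketch: a trivial custom backend for torch.compile that prints the captured
# FX graph and then runs it unoptimized. Illustrative only.
import torch

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()   # show the FX graph TorchDynamo captured
    return gm.forward          # return a callable; here, plain eager execution

@torch.compile(backend=inspect_backend)
def g(x):
    return torch.sin(x) + torch.cos(x)

g(torch.randn(4))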
@@ -463,10 +463,8 @@ def bar(a, b):
 
 # Reset since we are using a different backend.
 torch._dynamo.reset()
-explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = torch._dynamo.explain(
-    bar, torch.randn(10), torch.randn(10)
-)
-print(explanation_verbose)
+explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
+print(explain_output)
 
 ######################################################################
 # In order to maximize speedup, graph breaks should be limited.
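The replacement lines above use the 2.1-style calling convention, where ``torch._dynamo.explain`` wraps the function first and the example inputs are passed in a second call. A self-contained sketch of that usage; ``bar`` is re-stated here as a plausible graph-breaking function, not copied verbatim from the tutorial:

# Sketch: the updated torch._dynamo.explain API, which returns a single
# explanation object instead of a tuple. 'bar' is a stand-in example.
import torch
import torch._dynamo

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:          # data-dependent branch that forces a graph break
        b = b * -1
    return x * b

torch._dynamo.reset()
explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
print(explain_output)        # summarizes graph count, graph breaks, and break reasons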
@@ -487,16 +485,18 @@ def bar(a, b):
 print(opt_model(generate_data(16)[0]))
 
 ######################################################################
+# <!----TODO: replace this section with a link to the torch.export tutorial when done --->
+#
 # Finally, if we simply want TorchDynamo to output the FX graph for export,
 # we can use ``torch._dynamo.export``. Note that ``torch._dynamo.export``, like
 # ``fullgraph=True``, raises an error if TorchDynamo breaks the graph.
 
 try:
-    torch._dynamo.export(bar, torch.randn(10), torch.randn(10))
+    torch._dynamo.export(bar)(torch.randn(10), torch.randn(10))
 except:
     tb.print_exc()
 
-model_exp = torch._dynamo.export(init_model(), generate_data(16)[0])
+model_exp = torch._dynamo.export(init_model())(generate_data(16)[0])
 print(model_exp[0](generate_data(16)[0]))
 
 ######################################################################
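``torch._dynamo.export`` changes in the same way: the callable is passed first and the example inputs in a second call. A minimal sketch under the assumption of a simple stand-in model, not the tutorial's ``init_model()``:

# Sketch: the updated torch._dynamo.export calling convention.
# 'Net' is a stand-in model, not the tutorial's init_model().
import torch
import torch._dynamo
import torch.nn as nn

class Net(nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2

model = Net()
inp = torch.randn(4, 8)

# export(fn)(*example_inputs) returns a result whose first element is the graph module
graph_module, guards = torch._dynamo.export(model)(inp)
print(graph_module(inp))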
