Skip to content

torch.compile tutorial update for pt2 stable release #2224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .jenkins/validate_tutorials_built.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
"recipes/profiler_recipe",
"recipes/save_load_across_devices",
"recipes/warmstarting_model_using_parameters_from_a_different_model",
"torch_compile_tutorial_",
"recipes/dynamic_quantization",
"recipes/saving_and_loading_a_general_checkpoint",
"recipes/benchmark",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#
# **Required pip Dependencies**
#
# - ``torch >= 1.14``
# - ``torch >= 2.0``
# - ``torchvision``
# - ``numpy``
# - ``scipy``
Expand All @@ -52,9 +52,6 @@

import torch

import torch._inductor.config
torch._inductor.config.cpp.cxx = ("g++",)

def foo(x, y):
a = torch.sin(x)
b = torch.cos(x)
Expand Down Expand Up @@ -133,6 +130,11 @@ def evaluate(mod, inp):
return mod(inp)

model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")

inp = generate_data(16)[0]
Expand Down Expand Up @@ -174,8 +176,7 @@ def evaluate(mod, inp):

######################################################################
# And indeed, we can see that running our model with ``torch.compile``
# results in a significant speedup. On an NVIDIA A100 GPU, we observe a
# 2.3x speedup. Speedup mainly comes from reducing Python overhead and
# results in a significant speedup. Speedup mainly comes from reducing Python overhead and
# GPU read/writes, and so the observed speedup may vary on factors such as model
# architecture and batch size. For example, if a model's architecture is simple
# and the amount of data is large, then the bottleneck would be
Expand Down Expand Up @@ -231,9 +232,8 @@ def train(mod, data):

######################################################################
# Again, we can see that ``torch.compile`` takes longer in the first
# iteration, as it must compile the model, but afterward, we see
# significant speedups compared to eager. On an NVIDIA A100 GPU, we
# observe a 2.2x speedup.
# iteration, as it must compile the model, but in subsequent iterations, we see
# significant speedups compared to eager.

######################################################################
# Comparison to TorchScript and FX Tracing
Expand Down Expand Up @@ -297,6 +297,9 @@ def test_fns(fn1, fn2, args):
# Now we can see that ``torch.compile`` correctly handles
# data-dependent control flow.

# Reset since we are using a different mode.
torch._dynamo.reset()

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
Expand Down Expand Up @@ -394,7 +397,6 @@ def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor])
gm.graph.print_tabular()
return gm.forward

import torch._dynamo
# Reset since we are using a different backend.
torch._dynamo.reset()

Expand Down Expand Up @@ -489,4 +491,4 @@ def bar(a, b):
# In this tutorial, we introduced ``torch.compile`` by covering
# basic usage, demonstrating speedups over eager mode, comparing to previous
# PyTorch compiler solutions, and briefly investigating TorchDynamo and its interactions
# with FX graphs. We hope that you will give ``torch.compile`` a try!
# with FX graphs. We hope that you will give ``torch.compile`` a try!
Loading