diff --git a/intermediate_source/torch_compile_full_example.py b/intermediate_source/torch_compile_full_example.py
new file mode 100644
index 00000000000..51e42bd7909
--- /dev/null
+++ b/intermediate_source/torch_compile_full_example.py
@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+
+"""
+``torch.compile`` End-to-End Tutorial
+=====================================
+**Author:** William Wen
+"""
+
+import warnings
+
+######################################################################
+# ``torch.compile`` is the new way to speed up your PyTorch code!
+# ``torch.compile`` makes PyTorch code run faster by
+# JIT-compiling PyTorch code into optimized kernels,
+# while requiring minimal code changes.
+#
+# This tutorial covers an end-to-end example of training and evaluating a
+# real model with ``torch.compile``. For a gentler introduction to ``torch.compile``,
+# please check out our `torch.compile tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
+#
+# **Required pip Dependencies**
+#
+# - ``torch >= 2.0``
+# - ``torchvision``
+
+# NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in
+# order to reproduce the speedup numbers shown below and documented elsewhere.
+
+import torch
+
+gpu_ok = False
+if torch.cuda.is_available():
+    device_cap = torch.cuda.get_device_capability()
+    if device_cap in ((7, 0), (8, 0), (9, 0)):
+        gpu_ok = True
+
+if not gpu_ok:
+    warnings.warn(
+        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
+        "than expected."
+    )
+
+
+######################################################################
+# Let's demonstrate how using ``torch.compile`` can speed up a real model.
+# We will compare standard eager mode and
+# ``torch.compile`` by evaluating and training a ``torchvision`` model on random data.
+#
+# Before we start, we need to define some utility functions.
+
+
+# Returns the result of running `fn()` and the time it took for `fn()` to run,
+# in seconds. We use CUDA events and synchronization for the most accurate
+# measurements.
+def timed(fn):
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    result = fn()
+    end.record()
+    torch.cuda.synchronize()
+    return result, start.elapsed_time(end) / 1000
+
+
+# Generates random input and targets data for the model, where `b` is
+# batch size.
+def generate_data(b):
+    return (
+        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
+        torch.randint(1000, (b,)).cuda(),
+    )
+
+
+N_ITERS = 10
+
+from torchvision.models import densenet121
+
+
+def init_model():
+    return densenet121().cuda()
+
+
+######################################################################
+# First, let's compare inference.
+#
+# Note that in the call to ``torch.compile``, we have the additional
+# ``mode`` argument, which we will discuss below.
+
+model = init_model()
+
+model_opt = torch.compile(model, mode="reduce-overhead")
+
+inp = generate_data(16)[0]
+with torch.no_grad():
+    print("eager:", timed(lambda: model(inp))[1])
+    print("compile:", timed(lambda: model_opt(inp))[1])
+
+######################################################################
+# Notice that ``torch.compile`` takes a lot longer to complete its first run
+# compared to eager. This is because ``torch.compile`` compiles
+# the model into optimized kernels as it executes. In our example, the
+# structure of the model doesn't change, so recompilation is not
+# needed. If we run our optimized model several more times, we should
+# see a significant improvement compared to eager.
+
+eager_times = []
+for i in range(N_ITERS):
+    inp = generate_data(16)[0]
+    with torch.no_grad():
+        _, eager_time = timed(lambda: model(inp))
+    eager_times.append(eager_time)
+    print(f"eager eval time {i}: {eager_time}")
+
+print("~" * 10)
+
+compile_times = []
+for i in range(N_ITERS):
+    inp = generate_data(16)[0]
+    with torch.no_grad():
+        _, compile_time = timed(lambda: model_opt(inp))
+    compile_times.append(compile_time)
+    print(f"compile eval time {i}: {compile_time}")
+print("~" * 10)
+
+import numpy as np
+
+eager_med = np.median(eager_times)
+compile_med = np.median(compile_times)
+speedup = eager_med / compile_med
+assert speedup > 1
+print(
+    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
+)
+print("~" * 10)
+
+######################################################################
+# And indeed, we can see that running our model with ``torch.compile``
+# results in a significant speedup. Speedup mainly comes from reducing Python overhead and
+# GPU read/writes, so the observed speedup may vary with factors such as model
+# architecture and batch size. For example, if a model's architecture is simple
+# and the amount of data is large, then the bottleneck would be
+# GPU compute and the observed speedup may be less significant.
+#
+# You may also see different speedup results depending on the chosen ``mode``
+# argument. The ``"reduce-overhead"`` mode uses CUDA graphs to further reduce
+# the overhead of Python. For your own models,
+# you may need to experiment with different modes to maximize speedup; a rough
+# sketch of such an experiment appears at the end of this tutorial. You can
+# read more about modes `here <https://pytorch.org/docs/stable/generated/torch.compile.html>`__.
+#
+# You might also notice that the second time we run our model with ``torch.compile`` is significantly
+# slower than the other runs, although it is much faster than the first run. This is because the ``"reduce-overhead"``
+# mode runs a few warm-up iterations for CUDA graphs.
+#
+# Now, let's compare training.

+model = init_model()
+opt = torch.optim.Adam(model.parameters())
+
+
+def train(mod, data):
+    opt.zero_grad(True)
+    pred = mod(data[0])
+    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
+    loss.backward()
+    opt.step()
+
+
+eager_times = []
+for i in range(N_ITERS):
+    inp = generate_data(16)
+    _, eager_time = timed(lambda: train(model, inp))
+    eager_times.append(eager_time)
+    print(f"eager train time {i}: {eager_time}")
+print("~" * 10)
+
+model = init_model()
+opt = torch.optim.Adam(model.parameters())
+train_opt = torch.compile(train, mode="reduce-overhead")
+
+compile_times = []
+for i in range(N_ITERS):
+    inp = generate_data(16)
+    _, compile_time = timed(lambda: train_opt(model, inp))
+    compile_times.append(compile_time)
+    print(f"compile train time {i}: {compile_time}")
+print("~" * 10)
+
+eager_med = np.median(eager_times)
+compile_med = np.median(compile_times)
+speedup = eager_med / compile_med
+assert speedup > 1
+print(
+    f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
+)
+print("~" * 10)
+
+######################################################################
+# Again, we can see that ``torch.compile`` takes longer in the first
+# iteration, as it must compile the model, but in subsequent iterations, we see
+# significant speedups compared to eager.
+#
+# We remark that the speedup numbers presented in this tutorial are for
+# demonstration purposes only. Official speedup values can be seen at the
+# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__.
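+
+######################################################################
+# Finally, as mentioned above, you may need to experiment with different
+# ``mode`` arguments to maximize speedup. The sketch below is illustrative
+# rather than a rigorous benchmark: ``"default"``, ``"reduce-overhead"``, and
+# ``"max-autotune"`` are documented ``torch.compile`` modes, and results will
+# vary by model and hardware.
+
+import torch._dynamo
+
+for mode in ("default", "reduce-overhead", "max-autotune"):
+    torch._dynamo.reset()  # clear compiled code between experiments
+    model_mode = torch.compile(init_model(), mode=mode)
+    inp = generate_data(16)[0]
+    with torch.no_grad():
+        for _ in range(3):
+            model_mode(inp)  # warm-up runs (compilation + CUDA graph capture)
+        print(mode, timed(lambda: model_mode(inp))[1])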
diff --git a/intermediate_source/torch_compile_tutorial.py b/intermediate_source/torch_compile_tutorial.py
index de31af04dc1..650b97b6bda 100644
--- a/intermediate_source/torch_compile_tutorial.py
+++ b/intermediate_source/torch_compile_tutorial.py
@@ -7,215 +7,139 @@
 """
 
 ######################################################################
-# ``torch.compile`` is the latest method to speed up your PyTorch code!
+# ``torch.compile`` is the new way to speed up your PyTorch code!
 # ``torch.compile`` makes PyTorch code run faster by
 # JIT-compiling PyTorch code into optimized kernels,
-# all while requiring minimal code changes.
+# while requiring minimal code changes.
 #
-# In this tutorial, we cover basic ``torch.compile`` usage,
-# and demonstrate the advantages of ``torch.compile`` over
-# previous PyTorch compiler solutions, such as
-# `TorchScript <https://pytorch.org/docs/stable/jit.html>`__ and
-# `FX Tracing <https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace>`__.
+# ``torch.compile`` accomplishes this by tracing through
+# your Python code, looking for PyTorch operations.
+# Code that is difficult to trace will result in a
+# **graph break**, which is a lost optimization opportunity, rather
+# than an error or silent incorrectness.
+#
+# ``torch.compile`` is available in PyTorch 2.0 and later.
+#
+# This introduction covers basic ``torch.compile`` usage
+# and demonstrates the advantages of ``torch.compile`` over
+# our previous PyTorch compiler solution,
+# `TorchScript <https://pytorch.org/docs/stable/jit.html>`__.
+#
+# For an end-to-end example on a real model, check out our
+# `end-to-end torch.compile tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.
 #
 # **Contents**
 #
 # .. contents::
 #    :local:
-#
-# **Required pip Dependencies**
+
+# **Required pip dependencies for this tutorial**
 #
 # - ``torch >= 2.0``
-# - ``torchvision``
 # - ``numpy``
 # - ``scipy``
-# - ``tabulate``
 #
-# **System Requirements**
+# **System requirements**
 # - A C++ compiler, such as ``g++``
 # - Python development package (``python-devel``/``python-dev``)
 
 ######################################################################
-# NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in
-# order to reproduce the speedup numbers shown below and documented elsewhere.
+# Basic Usage
+# ------------
+#
+# We turn on some logging to help us see what ``torch.compile`` is doing
+# under the hood in this tutorial.
+# The following code will print out the PyTorch ops that ``torch.compile`` traced.
 
 import torch
-import warnings
 
-gpu_ok = False
-if torch.cuda.is_available():
-    device_cap = torch.cuda.get_device_capability()
-    if device_cap in ((7, 0), (8, 0), (9, 0)):
-        gpu_ok = True
-
-if not gpu_ok:
-    warnings.warn(
-        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
-        "than expected."
-    )
+torch._logging.set_logs(graph_code=True)
 
 ######################################################################
-# Basic Usage
-# ------------
-#
-# ``torch.compile`` is included in the latest PyTorch.
-# Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly
-# binary. If Triton is still missing, try installing ``torchtriton`` via pip
-# (``pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"``
-# for CUDA 11.7).
-#
-# Arbitrary Python functions can be optimized by passing the callable to
-# ``torch.compile``. We can then call the returned optimized
-# function in place of the original function.
+# ``torch.compile`` takes an arbitrary Python function, which can be passed
+# directly or applied as a decorator.
+
 def foo(x, y):
     a = torch.sin(x)
     b = torch.cos(y)
     return a + b
+
+
 opt_foo1 = torch.compile(foo)
-print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
+print(opt_foo1(torch.randn(3, 3), torch.randn(3, 3)))
 
-######################################################################
-# Alternatively, we can decorate the function.
-t1 = torch.randn(10, 10)
-t2 = torch.randn(10, 10)
 
 @torch.compile
 def opt_foo2(x, y):
     a = torch.sin(x)
     b = torch.cos(y)
     return a + b
-print(opt_foo2(t1, t2))
 
+
+print(opt_foo2(torch.randn(3, 3), torch.randn(3, 3)))
 
 ######################################################################
-# torch.compile and Nested Calls
-# ------------------------------
-# Nested function calls within the decorated function will also be compiled.
+# ``torch.compile`` is applied recursively, so nested function calls
+# within the top-level compiled function will also be compiled.
+
 
-def nested_function(x):
+def inner(x):
     return torch.sin(x)
 
+
 @torch.compile
-def outer_function(x, y):
-    a = nested_function(x)
+def outer(x, y):
    a = inner(x)
     b = torch.cos(y)
     return a + b
 
-print(outer_function(t1, t2))
+
+print(outer(torch.randn(3, 3), torch.randn(3, 3)))
+
 
 ######################################################################
-# In the same fashion, when compiling a module all sub-modules and methods
-# within it, that are not in a skip list, are also compiled.
+# We can also optimize ``torch.nn.Module`` instances by either calling
+# its ``.compile()`` method or by directly ``torch.compile``-ing the module.
+# This is equivalent to ``torch.compile``-ing the module's ``__call__`` method
+# (which indirectly calls ``forward``).
+
 
-class OuterModule(torch.nn.Module):
+
+class MyModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.inner_module = MyModule()
-        self.outer_lin = torch.nn.Linear(10, 2)
+        self.lin = torch.nn.Linear(3, 3)
 
     def forward(self, x):
-        x = self.inner_module(x)
-        return torch.nn.functional.relu(self.outer_lin(x))
+        return torch.nn.functional.relu(self.lin(x))
 
-outer_mod = OuterModule()
-outer_mod.compile()
-print(outer_mod(t))
 
 ######################################################################
-# We can also disable some functions from being compiled by using
-# ``torch.compiler.disable``. Suppose you want to disable the tracing on just
-# the ``complex_function`` function, but want to continue the tracing back in
-# ``complex_conjugate``. In this case, you can use
-# ``torch.compiler.disable(recursive=False)`` option. Otherwise, the default is
-# ``recursive=True``.
-
-def complex_conjugate(z):
-    return torch.conj(z)
-
-@torch.compiler.disable(recursive=False)
-def complex_function(real, imag):
-    # Assuming this function cause problems in the compilation
-    z = torch.complex(real, imag)
-    return complex_conjugate(z)
-
-def outer_function():
-    real = torch.tensor([2, 3], dtype=torch.float32)
-    imag = torch.tensor([4, 5], dtype=torch.float32)
-    z = complex_function(real, imag)
-    return torch.abs(z)
-
-# Try to compile the outer_function
-try:
-    opt_outer_function = torch.compile(outer_function)
-    print(opt_outer_function())
-except Exception as e:
-    print("Compilation of outer_function failed:", e)
+mod1 = MyModule()
+mod1.compile()
+print(mod1(torch.randn(3, 3)))
 
-######################################################################
-# Best Practices and Recommendations
-# ----------------------------------
-#
-# Behavior of ``torch.compile`` with Nested Modules and Function Calls
-#
-# When you use ``torch.compile``, the compiler will try to recursively compile
-# every function call inside the target function or module inside the target
-# function or module that is not in a skip list (such as built-ins, some functions in
-# the torch.* namespace).
-#
-# **Best Practices:**
-#
-# 1. **Top-Level Compilation:** One approach is to compile at the highest level
-#    possible (i.e., when the top-level module is initialized/called) and
-#    selectively disable compilation when encountering excessive graph breaks or
-#    errors. If there are still many compile issues, compile individual
-#    subcomponents instead.
-#
-# 2. **Modular Testing:** Test individual functions and modules with ``torch.compile``
-#    before integrating them into larger models to isolate potential issues.
-#
-# 3. **Disable Compilation Selectively:** If certain functions or sub-modules
-#    cannot be handled by `torch.compile`, use the `torch.compiler.disable` context
-#    managers to recursively exclude them from compilation.
-#
-# 4. **Compile Leaf Functions First:** In complex models with multiple nested
-#    functions and modules, start by compiling the leaf functions or modules first.
-#    For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.
-#
-# 5. **Prefer ``mod.compile()`` over ``torch.compile(mod)``:** Avoids ``_orig_`` prefix issues in ``state_dict``.
-#
-# 6. **Use ``fullgraph=True`` to catch graph breaks:** Helps ensure end-to-end compilation, maximizing speedup
-#    and compatibility with ``torch.export``.
+mod2 = MyModule()
+mod2 = torch.compile(mod2)
+print(mod2(torch.randn(3, 3)))
 
 ######################################################################
 # Demonstrating Speedups
 # -----------------------
 #
-# Let's now demonstrate that using ``torch.compile`` can speed
-# up real models. We will compare standard eager mode and
-# ``torch.compile`` by evaluating and training a ``torchvision`` model on random data.
-#
-# Before we start, we need to define some utility functions.
+# Now let's demonstrate how ``torch.compile`` speeds up a simple PyTorch example.
+# For a demonstration on a more complex model, see our
+# `end-to-end torch.compile tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.
+
+
+def foo3(x):
+    y = x + 1
+    z = torch.nn.functional.relu(y)
+    u = z * 2
+    return u
+
+
+opt_foo3 = torch.compile(foo3)
+
 # Returns the result of running `fn()` and the time it took for `fn()` to run,
 # in seconds. We use CUDA events and synchronization for the most accurate
@@ -227,74 +151,47 @@ def timed(fn):
     result = fn()
     end.record()
     torch.cuda.synchronize()
-    return result, start.elapsed_time(end) / 1000
-
-# Generates random input and targets data for the model, where `b` is
-# batch size.
-def generate_data(b):
-    return (
-        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
-        torch.randint(1000, (b,)).cuda(),
-    )
+    return result, start.elapsed_time(end) / 1000
 
-N_ITERS = 10
-from torchvision.models import densenet121
-def init_model():
-    return densenet121().to(torch.float32).cuda()
+
+inp = torch.randn(4096, 4096).cuda()
+print("compile:", timed(lambda: opt_foo3(inp))[1])
+print("eager:", timed(lambda: foo3(inp))[1])
 
 ######################################################################
-# First, let's compare inference.
-#
-# Note that in the call to ``torch.compile``, we have the additional
-# ``mode`` argument, which we will discuss below.
-
-model = init_model()
-
-# Reset since we are using a different mode.
-import torch._dynamo
-torch._dynamo.reset()
-
-model_opt = torch.compile(model, mode="reduce-overhead")
-
-inp = generate_data(16)[0]
-with torch.no_grad():
-    print("eager:", timed(lambda: model(inp))[1])
-    print("compile:", timed(lambda: model_opt(inp))[1])
-
-######################################################################
-# Notice that ``torch.compile`` takes a lot longer to complete
-# compared to eager. This is because ``torch.compile`` compiles
-# the model into optimized kernels as it executes. In our example, the
-# structure of the model doesn't change, and so recompilation is not
-# needed. So if we run our optimized model several more times, we should
+# Notice that ``torch.compile`` appears to take a lot longer to complete
+# compared to eager. This is because ``torch.compile`` takes extra time to compile
+# the model on the first execution.
+# ``torch.compile`` re-uses compiled code whenever possible,
+# so if we run our optimized model several more times, we should
 # see a significant improvement compared to eager.
 
+# Turn off logging for now to prevent spam.
+torch._logging.set_logs(graph_code=False)
+
 eager_times = []
-for i in range(N_ITERS):
-    inp = generate_data(16)[0]
-    with torch.no_grad():
-        _, eager_time = timed(lambda: model(inp))
+for i in range(10):
+    _, eager_time = timed(lambda: foo3(inp))
     eager_times.append(eager_time)
-    print(f"eager eval time {i}: {eager_time}")
-
+    print(f"eager time {i}: {eager_time}")
 print("~" * 10)
 
 compile_times = []
-for i in range(N_ITERS):
-    inp = generate_data(16)[0]
-    with torch.no_grad():
-        _, compile_time = timed(lambda: model_opt(inp))
+for i in range(10):
+    _, compile_time = timed(lambda: opt_foo3(inp))
     compile_times.append(compile_time)
-    print(f"compile eval time {i}: {compile_time}")
+    print(f"compile time {i}: {compile_time}")
 print("~" * 10)
 
 import numpy as np
+
 eager_med = np.median(eager_times)
 compile_med = np.median(compile_times)
 speedup = eager_med / compile_med
-assert(speedup > 1)
-print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
+assert speedup > 1
+print(
+    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
+)
 print("~" * 10)
 
 ######################################################################
@@ -305,151 +202,72 @@ def init_model():
 # and the amount of data is large, then the bottleneck would be
 # GPU compute and the observed speedup may be less significant.
 #
-# You may also see different speedup results depending on the chosen ``mode``
-# argument. The ``"reduce-overhead"`` mode uses CUDA graphs to further reduce
-# the overhead of Python. For your own models,
-# you may need to experiment with different modes to maximize speedup. You can
-# read more about modes `here <https://pytorch.org/docs/stable/generated/torch.compile.html>`__.
-#
-# You may might also notice that the second time we run our model with ``torch.compile`` is significantly
-# slower than the other runs, although it is much faster than the first run. This is because the ``"reduce-overhead"``
-# mode runs a few warm-up iterations for CUDA graphs.
-#
-# For general PyTorch benchmarking, you can try using ``torch.utils.benchmark`` instead of the ``timed``
-# function we defined above. We wrote our own timing function in this tutorial to show
-# ``torch.compile``'s compilation latency.
-#
-# Now, let's consider comparing training.
-
-model = init_model()
-opt = torch.optim.Adam(model.parameters())
-
-def train(mod, data):
-    opt.zero_grad(True)
-    pred = mod(data[0])
-    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
-    loss.backward()
-    opt.step()
-
-eager_times = []
-for i in range(N_ITERS):
-    inp = generate_data(16)
-    _, eager_time = timed(lambda: train(model, inp))
-    eager_times.append(eager_time)
-    print(f"eager train time {i}: {eager_time}")
-print("~" * 10)
-
-model = init_model()
-opt = torch.optim.Adam(model.parameters())
-train_opt = torch.compile(train, mode="reduce-overhead")
-
-compile_times = []
-for i in range(N_ITERS):
-    inp = generate_data(16)
-    _, compile_time = timed(lambda: train_opt(model, inp))
-    compile_times.append(compile_time)
-    print(f"compile train time {i}: {compile_time}")
-print("~" * 10)
-
-eager_med = np.median(eager_times)
-compile_med = np.median(compile_times)
-speedup = eager_med / compile_med
-assert(speedup > 1)
-print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
-print("~" * 10)
+# To see speedups on a real model, check out our
+# `end-to-end torch.compile tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.
 
 ######################################################################
-# Again, we can see that ``torch.compile`` takes longer in the first
-# iteration, as it must compile the model, but in subsequent iterations, we see
-# significant speedups compared to eager.
+# Benefits over TorchScript
+# -------------------------
 #
-# We remark that the speedup numbers presented in this tutorial are for
-# demonstration purposes only. Official speedup values can be seen at the
-# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__.
-
-######################################################################
-# Comparison to TorchScript and FX Tracing
-# -----------------------------------------
-#
-# We have seen that ``torch.compile`` can speed up PyTorch code.
-# Why else should we use ``torch.compile`` over existing PyTorch
-# compiler solutions, such as TorchScript or FX Tracing? Primarily, the
+# Why should we use ``torch.compile`` over TorchScript? Primarily, the
 # advantage of ``torch.compile`` lies in its ability to handle
 # arbitrary Python code with minimal changes to existing code.
 #
-# One case that ``torch.compile`` can handle that other compiler
-# solutions struggle with is data-dependent control flow (the
-# ``if x.sum() < 0:`` line below).
+# Compare this to TorchScript, which has a tracing mode (``torch.jit.trace``) and
+# a scripting mode (``torch.jit.script``).
+# Tracing mode is susceptible to
+# silent incorrectness, while scripting mode requires significant code changes
+# and will raise errors on unsupported Python code.
+#
+# For example, TorchScript tracing silently fails on data-dependent control flow
+# (the ``if x.sum() < 0:`` line below)
+# because only the actual control flow path is traced.
+# In comparison, ``torch.compile`` is able to correctly handle it.
+
 
 def f1(x, y):
     if x.sum() < 0:
         return -y
     return y
 
-# Test that `fn1` and `fn2` return the same result, given
-# the same arguments `args`. Typically, `fn1` will be an eager function
-# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
+
+# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
 def test_fns(fn1, fn2, args):
     out1 = fn1(*args)
     out2 = fn2(*args)
     return torch.allclose(out1, out2)
 
+
 inp1 = torch.randn(5, 5)
 inp2 = torch.randn(5, 5)
 
-######################################################################
-# TorchScript tracing ``f1`` results in
-# silently incorrect results, since only the actual control flow path
-# is traced.
-
 traced_f1 = torch.jit.trace(f1, (inp1, inp2))
 print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
 print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))
 
-######################################################################
-# FX tracing ``f1`` results in an error due to the presence of
-# data-dependent control flow.
-
-import traceback as tb
-try:
-    torch.fx.symbolic_trace(f1)
-except:
-    tb.print_exc()
-
-######################################################################
-# If we provide a value for ``x`` as we try to FX trace ``f1``, then
-# we run into the same problem as TorchScript tracing, as the data-dependent
-# control flow is removed in the traced function.
-
-fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
-print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
-print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))
-
-######################################################################
-# Now we can see that ``torch.compile`` correctly handles
-# data-dependent control flow.
-
-# Reset since we are using a different mode.
-torch._dynamo.reset()
-
 compile_f1 = torch.compile(f1)
 print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
 print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
 print("~" * 10)
 
 ######################################################################
-# TorchScript scripting can handle data-dependent control flow, but this
-# solution comes with its own set of problems. Namely, TorchScript scripting
-# can require major code changes and will raise errors when unsupported Python
+# TorchScript scripting can handle data-dependent control flow,
+# but it can require major code changes and will raise errors when unsupported Python
 # is used.
 #
 # In the example below, we omit TorchScript type annotations and receive
 # a TorchScript error because the input type for argument ``y``, an ``int``,
 # does not match the default argument type, ``torch.Tensor``.
+# In comparison, ``torch.compile`` works without requiring any type annotations.
+
+
+import traceback as tb
+
+torch._logging.set_logs(graph_code=True)
+
 
 def f2(x, y):
     return x + y
 
+
 inp1 = torch.randn(5, 5)
 inp2 = 3
 
@@ -459,18 +277,17 @@ def f2(x, y):
 except:
     tb.print_exc()
 
-######################################################################
-# However, ``torch.compile`` is easily able to handle ``f2``.
-
 compile_f2 = torch.compile(f2)
 print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
 print("~" * 10)
 
 ######################################################################
 # Another case that ``torch.compile`` handles well compared to
-# previous compilers solutions is the usage of non-PyTorch functions.
+# both TorchScript tracing and scripting is the usage of third-party library functions.
 
 import scipy
+
+
 def f3(x):
     x = x * 2
     x = scipy.fft.dct(x.numpy())
@@ -478,67 +295,43 @@ def f3(x):
     x = x * 2
     return x
 
+
 ######################################################################
 # TorchScript tracing treats results from non-PyTorch function calls
 # as constants, and so our results can be silently wrong.
+# TorchScript scripting disallows non-PyTorch function calls.
+# On the other hand, ``torch.compile`` is easily able to handle
+# the non-PyTorch function call.
+
 inp1 = torch.randn(5, 5)
 inp2 = torch.randn(5, 5)
 traced_f3 = torch.jit.trace(f3, (inp1,))
 print("traced 3:", test_fns(f3, traced_f3, (inp2,)))
 
-######################################################################
-# TorchScript scripting and FX tracing disallow non-PyTorch function calls.
-
 try:
     torch.jit.script(f3)
 except:
     tb.print_exc()
 
-try:
-    torch.fx.symbolic_trace(f3)
-except:
-    tb.print_exc()
-
-######################################################################
-# In comparison, ``torch.compile`` is easily able to handle
-# the non-PyTorch function call.
-
 compile_f3 = torch.compile(f3)
 print("compile 3:", test_fns(f3, compile_f3, (inp2,)))
 
+
 ######################################################################
-# TorchDynamo and FX Graphs
-# --------------------------
+# Graph Breaks
+# ------------------------------------
+# The graph break is one of the most fundamental concepts within ``torch.compile``.
+# It allows ``torch.compile`` to handle arbitrary Python code by interrupting
+# compilation, running the unsupported code, then resuming compilation.
+# The term "graph break" comes from the fact that ``torch.compile`` attempts
+# to capture and optimize the PyTorch operation graph. When unsupported Python code is encountered,
+# this graph must be "broken".
+# Graph breaks result in lost optimization opportunities, which is still undesirable,
+# but it is better than silent incorrectness or a hard crash.
 #
-# One important component of ``torch.compile`` is TorchDynamo.
-# TorchDynamo is responsible for JIT compiling arbitrary Python code into
-# `FX graphs <https://pytorch.org/docs/stable/fx.html>`__, which can
-# then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode
-# during runtime and detecting calls to PyTorch operations.
-#
-# Normally, TorchInductor, another component of ``torch.compile``,
-# further compiles the FX graphs into optimized kernels,
-# but TorchDynamo allows for different backends to be used. In order to inspect
-# the FX graphs that TorchDynamo outputs, let us create a custom backend that
-# outputs the FX graph and simply returns the graph's unoptimized forward method.
-
-from typing import List
-def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
-    print("custom backend called with FX graph:")
-    gm.graph.print_tabular()
-    return gm.forward
-
-# Reset since we are using a different backend.
-torch._dynamo.reset()
+# Let's look at a data-dependent control flow example to better see how graph breaks work.
-opt_model = torch.compile(init_model(), backend=custom_backend)
-opt_model(generate_data(16)[0])
-
-######################################################################
-# Using our custom backend, we can now see how TorchDynamo is able to handle
-# data-dependent control flow. Consider the function below, where the line
-# ``if b.sum() < 0`` is the source of data-dependent control flow.
 
 def bar(a, b):
     x = a / (torch.abs(a) + 1)
@@ -546,23 +339,24 @@ def bar(a, b):
         b = b * -1
     return x * b
 
-opt_bar = torch.compile(bar, backend=custom_backend)
-inp1 = torch.randn(10)
-inp2 = torch.randn(10)
+
+opt_bar = torch.compile(bar)
+inp1 = torch.ones(10)
+inp2 = torch.ones(10)
 opt_bar(inp1, inp2)
 opt_bar(inp1, -inp2)
 
 ######################################################################
-# The output reveals that TorchDynamo extracted 3 different FX graphs
-# corresponding the following code (order may differ from the output above):
+# The first time we run ``bar``, we see that ``torch.compile`` traced 2 graphs
+# corresponding to the following code (noting that ``b.sum() < 0`` is False):
 #
-# 1. ``x = a / (torch.abs(a) + 1)``
-# 2. ``b = b * -1; return x * b``
-# 3. ``return x * b``
+# 1. ``x = a / (torch.abs(a) + 1); b.sum()``
+# 2. ``return x * b``
 #
-# When TorchDynamo encounters unsupported Python features, such as data-dependent
-# control flow, it breaks the computation graph, lets the default Python
-# interpreter handle the unsupported code, then resumes capturing the graph.
+# The second time we run ``bar``, we take the other branch of the if statement
+# and we get 1 traced graph corresponding to the code ``b = b * -1; return x * b``.
+# We do not see a graph of ``x = a / (torch.abs(a) + 1)`` output the second time
+# since ``torch.compile`` cached this graph from the first run and re-used it.
 #
 # Let's investigate by example how TorchDynamo would step through ``bar``.
 # If ``b.sum() < 0``, then TorchDynamo would run graph 1, let
@@ -571,49 +365,78 @@ def bar(a, b):
 # would run graph 1, let Python determine the result of the conditional, then
 # run graph 2 (``return x * b``).
 #
-# This highlights a major difference between TorchDynamo and previous PyTorch
-# compiler solutions. When encountering unsupported Python features,
-# previous solutions either raise an error or silently fail.
-# TorchDynamo, on the other hand, will break the computation graph.
-#
-# We can see where TorchDynamo breaks the graph by using ``torch._dynamo.explain``:
+# We can see all graph breaks by using ``torch._logging.set_logs(graph_breaks=True)``.
 
-# Reset since we are using a different backend.
+# Reset to clear the torch.compile cache
 torch._dynamo.reset()
 
-explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
-print(explain_output)
+torch._logging.set_logs(graph_breaks=True)
+
+opt_bar(inp1, inp2)
+opt_bar(inp1, -inp2)
 
 ######################################################################
 # In order to maximize speedup, graph breaks should be limited.
 # We can force TorchDynamo to raise an error upon the first graph
 # break encountered by using ``fullgraph=True``:
 
-opt_bar = torch.compile(bar, fullgraph=True)
+# Reset to clear the torch.compile cache
+torch._dynamo.reset()
+
+opt_bar_fullgraph = torch.compile(bar, fullgraph=True)
 try:
-    opt_bar(torch.randn(10), torch.randn(10))
+    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
 except:
     tb.print_exc()
 
 ######################################################################
-# And below, we demonstrate that TorchDynamo does not break the graph on
-# the model we used above for demonstrating speedups.
+# In our example above, we can work around this graph break by replacing
+# the if statement with a ``torch.cond``:
+
+from functorch.experimental.control_flow import cond
+
+
+@torch.compile(fullgraph=True)
+def bar_fixed(a, b):
+    x = a / (torch.abs(a) + 1)
+
+    def true_branch(y):
+        return y * -1
+
+    def false_branch(y):
+        # NOTE: torch.cond doesn't allow aliased outputs
+        return y.clone()
+
+    b = cond(b.sum() < 0, true_branch, false_branch, (b,))
+    return x * b
+
+
+bar_fixed(inp1, inp2)
+bar_fixed(inp1, -inp2)
 
-opt_model = torch.compile(init_model(), fullgraph=True)
-print(opt_model(generate_data(16)[0]))
 
 ######################################################################
-# We can use ``torch.export`` (from PyTorch 2.1+) to extract a single, exportable
-# FX graph from the input PyTorch program. The exported graph is intended to be
-# run on different (i.e. Python-less) environments. One important restriction
-# is that the ``torch.export`` does not support graph breaks. Please check
+# In order to serialize graphs or to run graphs on different (i.e., Python-less)
+# environments, consider using ``torch.export`` instead (from PyTorch 2.1+);
+# a minimal sketch appears at the end of this tutorial.
+# One important restriction is that ``torch.export`` does not support graph breaks. Please check
 # `this tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__
 # for more details on ``torch.export``.
 
+######################################################################
+# Troubleshooting
+# ---------------
+# Is ``torch.compile`` failing to speed up your model? Is compile time unreasonably long?
+# Is your code recompiling excessively? Are you having difficulties dealing with graph breaks?
+# Are you looking for tips on how to best use ``torch.compile``?
+# Or maybe you simply want to learn more about the inner workings of ``torch.compile``?
+#
+# Check out the `torch.compile troubleshooting guide <https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html>`__!

 ######################################################################
 # Conclusion
 # ------------
 #
 # In this tutorial, we introduced ``torch.compile`` by covering
-# basic usage, demonstrating speedups over eager mode, comparing to previous
-# PyTorch compiler solutions, and briefly investigating TorchDynamo and its interactions
-# with FX graphs. We hope that you will give ``torch.compile`` a try!
+# basic usage, demonstrating speedups over eager mode, comparing to TorchScript,
+# and briefly describing graph breaks.
+#
+# For an end-to-end example on a real model, check out our
+# `end-to-end torch.compile tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.
+#
+# We hope that you will give ``torch.compile`` a try!
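+
+######################################################################
+# As a quick illustration of the ``torch.export`` API mentioned above, here is
+# a minimal sketch (it assumes PyTorch 2.1+ and re-uses the ``MyModule`` class
+# defined earlier; see the ``torch.export`` tutorial linked above for a proper
+# introduction). ``torch.export.export`` returns an ``ExportedProgram`` whose
+# graph can be inspected or serialized.
+
+# Export the module with an example input; unlike torch.compile, this
+# requires the whole program to be captured as a single graph (no graph breaks).
+exported_mod = torch.export.export(MyModule(), (torch.randn(3, 3),))
+print(exported_mod.graph)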