Commit e486157

Merge branch 'main' into mlazos/log-recipe

2 parents: 39784aa + 1ace4ee

File tree: 2 files changed, +98 -0 lines
recipes_source/compiling_optimizer.rst

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
(beta) Compiling the optimizer with torch.compile
=================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_

The optimizer is a key algorithm for training any deep learning model.
Since it is responsible for updating every model parameter, it can often
become the bottleneck in training performance for large models. In this recipe,
we will apply ``torch.compile`` to the optimizer to observe the GPU performance
improvement.

.. note::

   This tutorial requires PyTorch 2.2.0 or later.
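Before running the recipe, it can be worth confirming that the installed
version meets this requirement. The check below is a small illustrative
sketch, not part of the original recipe:

.. code-block:: python

   import torch

   # torch.__version__ looks like "2.2.0+cu121"; compare the major/minor parts
   major, minor = (int(v) for v in torch.__version__.split(".")[:2])
   if (major, minor) < (2, 2):
       raise RuntimeError(
           f"This recipe requires PyTorch 2.2.0 or later, found {torch.__version__}"
       )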
Model Setup
~~~~~~~~~~~~~~~~~~~~~
For this example, we'll use a simple sequence of linear layers.
Since we are only benchmarking the optimizer, the choice of model doesn't matter,
because optimizer performance is a function of the number of parameters.

Depending on what machine you are using, your exact results may vary.
.. code-block:: python

   import torch

   # Ten 1024x1024 linear layers without bias, created directly on the GPU
   model = torch.nn.Sequential(
       *[torch.nn.Linear(1024, 1024, bias=False, device="cuda") for _ in range(10)]
   )
   input = torch.rand(1024, device="cuda")
   output = model(input)
   # The backward pass populates the .grad fields the optimizer will consume
   output.sum().backward()
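Since optimizer cost scales with the number of parameters, it can be
instructive to check how many parameters this model actually has. This small
check is an illustrative addition, not part of the original recipe:

.. code-block:: python

   # 10 layers x (1024 * 1024) weights each = 10,485,760 parameters
   num_params = sum(p.numel() for p in model.parameters())
   print(f"number of parameters: {num_params}")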

Setting up and running the optimizer benchmark
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In this example, we'll use the Adam optimizer and create a helper function
that wraps ``step()`` in ``torch.compile()``.

.. note::

   ``torch.compile`` is only supported on CUDA devices with a compute capability of 7.0 or higher.
.. code-block:: python

   # Exit cleanly if we are on a device that doesn't support torch.compile
   if torch.cuda.get_device_capability() < (7, 0):
       print("Exiting because torch.compile is not supported on this device.")
       import sys

       sys.exit(0)


   opt = torch.optim.Adam(model.parameters(), lr=0.01)


   # fullgraph=False allows graph breaks, falling back to eager where needed
   @torch.compile(fullgraph=False)
   def fn():
       opt.step()


   # Let's define a helpful benchmarking function:
   import torch.utils.benchmark as benchmark


   def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
       t0 = benchmark.Timer(
           stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
       )
       # blocked_autorange() measures in seconds; convert the mean to microseconds
       return t0.blocked_autorange().mean * 1e6


   # Warmup runs to compile the function
   for _ in range(5):
       fn()

   eager_runtime = benchmark_torch_function_in_microseconds(opt.step)
   compiled_runtime = benchmark_torch_function_in_microseconds(fn)

   # The compiled step should be faster than the eager step
   assert eager_runtime > compiled_runtime

   print(f"eager runtime: {eager_runtime}us")
   print(f"compiled runtime: {compiled_runtime}us")

Sample Results:

* Eager runtime: 747.2437149845064us
* Compiled runtime: 392.07384741178us

On this machine, compiling the optimizer step yields roughly a 1.9x speedup
(747.24us / 392.07us ≈ 1.91).
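
To use the compiled step in practice, it can be dropped into an ordinary
training loop in place of ``opt.step()``. The loop below is an illustrative
sketch only, reusing the ``model``, ``opt``, and compiled ``fn`` defined
above with made-up random inputs; it is not part of the original recipe:

.. code-block:: python

   # Hypothetical training loop around the compiled optimizer step
   for _ in range(10):
       opt.zero_grad()
       loss = model(torch.rand(1024, device="cuda")).sum()
       loss.backward()
       fn()  # compiled equivalent of opt.step()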

recipes_source/recipes_index.rst

Lines changed: 10 additions & 0 deletions
@@ -284,6 +284,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/amx.html
    :tags: Model-Optimization
 
+.. (beta) Compiling the Optimizer with torch.compile
+
+.. customcarditem::
+   :header: (beta) Compiling the Optimizer with torch.compile
+   :card_description: Speed up the optimizer using torch.compile
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/compiling_optimizer.html
+   :tags: Model-Optimization
+
 .. Intel(R) Extension for PyTorch*
 
 .. customcarditem::

@@ -377,6 +386,7 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/recipes/amp_recipe
    /recipes/recipes/tuning_guide
    /recipes/recipes/intel_extension_for_pytorch
+   /recipes/compiling_optimizer
    /recipes/torch_compile_backend_ipex
    /recipes/torchscript_inference
    /recipes/deployment_with_flask
