
Commit 618efd9

mlazos, svekars, williamwen42, and malfet authored and committed
Added optimizer compile recipe (#2700)
* Added optimizer compile tutorial

---------

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: William Wen <williamwen@meta.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
1 parent e1a9c5b commit 618efd9

File tree

2 files changed: +98 -0 lines changed

recipes_source/compiling_optimizer.rst

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
(beta) Compiling the optimizer with torch.compile
==========================================================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_

The optimizer is a key algorithm for training any deep learning model.
Since it is responsible for updating every model parameter, it can often
become the bottleneck in training performance for large models. In this recipe,
we will apply ``torch.compile`` to the optimizer to observe the GPU performance
improvement.

.. note::

   This tutorial requires PyTorch 2.2.0 or later.
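
The recipe itself does not verify the installed version, but if you want to fail
fast on an older install, a guard along these lines can help (a minimal sketch;
the ``packaging`` dependency and the error message are our own additions, not
part of the recipe):

.. code-block:: python

    from packaging import version

    import torch

    # Illustrative guard: raise early if the PyTorch install is too old
    if version.parse(torch.__version__) < version.parse("2.2.0"):
        raise RuntimeError("This recipe requires PyTorch 2.2.0 or later.")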

Model Setup
~~~~~~~~~~~~~~~~~~~~~
For this example, we'll use a simple sequence of linear layers.
Since we are only benchmarking the optimizer, the choice of model doesn't matter
because optimizer performance is a function of the number of parameters.

Depending on what machine you are using, your exact results may vary.

.. code-block:: python

    import torch

    # Ten bias-free 1024x1024 linear layers on the GPU
    model = torch.nn.Sequential(
        *[torch.nn.Linear(1024, 1024, bias=False, device="cuda") for _ in range(10)]
    )
    input = torch.rand(1024, device="cuda")

    # Run forward and backward once so every parameter has a gradient
    output = model(input)
    output.sum().backward()
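Because optimizer runtime scales with the number of parameters, it can be handy
to confirm how many parameters the benchmark exercises. A quick check (our
illustration, not part of the original recipe):

.. code-block:: python

    # 10 layers x (1024 * 1024) weights = 10,485,760 parameters
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Optimizing {num_params:,} parameters")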

Setting up and running the optimizer benchmark
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In this example, we'll use the Adam optimizer and create a helper function
that wraps its ``step()`` call in ``torch.compile()``.

.. note::

   ``torch.compile`` is only supported on CUDA devices with compute capability >= 7.0.
.. code-block:: python

    # Exit cleanly if we are on a device that doesn't support torch.compile
    if torch.cuda.get_device_capability() < (7, 0):
        print("Exiting because torch.compile is not supported on this device.")
        import sys
        sys.exit(0)


    opt = torch.optim.Adam(model.parameters(), lr=0.01)


    @torch.compile(fullgraph=False)
    def fn():
        opt.step()


    # Let's define a helpful benchmarking function:
    import torch.utils.benchmark as benchmark


    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6


    # Warmup runs to compile the function
    for _ in range(5):
        fn()

    eager_runtime = benchmark_torch_function_in_microseconds(opt.step)
    compiled_runtime = benchmark_torch_function_in_microseconds(fn)

    assert eager_runtime > compiled_runtime

    print(f"eager runtime: {eager_runtime}us")
    print(f"compiled runtime: {compiled_runtime}us")

Sample Results:

* Eager runtime: 747.2437149845064us
* Compiled runtime: 392.07384741178us
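
The same wrapping pattern applies to other ``torch.optim`` optimizers; for
example, swapping in ``AdamW`` (a sketch under the same setup; the names
``opt_adamw`` and ``adamw_step`` are our own):

.. code-block:: python

    opt_adamw = torch.optim.AdamW(model.parameters(), lr=0.01)

    # Calling adamw_step() performs one compiled AdamW update
    @torch.compile(fullgraph=False)
    def adamw_step():
        opt_adamw.step()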

recipes_source/recipes_index.rst

Lines changed: 10 additions & 0 deletions
@@ -276,6 +276,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/amx.html
    :tags: Model-Optimization

+.. (beta) Compiling the Optimizer with torch.compile
+
+.. customcarditem::
+   :header: (beta) Compiling the Optimizer with torch.compile
+   :card_description: Speed up the optimizer using torch.compile
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/compiling_optimizer.html
+   :tags: Model-Optimization
+
 .. Intel(R) Extension for PyTorch*

 .. customcarditem::
@@ -368,6 +377,7 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/recipes/amp_recipe
    /recipes/recipes/tuning_guide
    /recipes/recipes/intel_extension_for_pytorch
+   /recipes/compiling_optimizer
    /recipes/torch_compile_backend_ipex
    /recipes/torchscript_inference
    /recipes/deployment_with_flask

0 commit comments
