|
1 |
| -""" |
2 | 1 | (beta) Compiling the optimizer with torch.compile
|
3 | 2 | ==========================================================================================
|
4 | 3 |
|
5 |
| - |
6 | 4 | **Author:** `Michael Lazos <https://github.com/mlazos>`_
|
7 |
| -""" |
8 |
| - |
9 |
| -###################################################################### |
10 |
| -# |
11 |
| -# The optimizer is a key algorithm for training any deep learning model. |
12 |
| -# Since it is responsible for updating every model parameter, it can often |
13 |
| -# become the bottleneck in training performance for large models. In this recipe, |
14 |
| -# we will apply ``torch.compile`` to the optimizer to observe the GPU performance |
15 |
| -# improvement. |
16 |
| -# |
17 |
| -# .. note:: |
18 |
| -# |
19 |
| -# This tutorial requires PyTorch 2.2.0 or later. |
20 |
| -# |
21 |
| - |
22 |
| - |
23 |
| -###################################################################### |
24 |
| -# Model Setup |
25 |
| -# ~~~~~~~~~~~~~~~~~~~~~ |
26 |
| -# For this example, we'll use a simple sequence of linear layers. |
27 |
| -# Since we are only benchmarking the optimizer, the choice of model doesn't matter |
28 |
| -# because optimizer performance is a function of the number of parameters. |
29 |
| -# |
30 |
| -# Depending on what machine you are using, your exact results may vary. |
31 |
| - |
32 |
| -import torch |
33 |
| - |
34 |
| -model = torch.nn.Sequential( |
35 |
| - *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)] |
36 |
| -) |
37 |
| -input = torch.rand(1024, device="cuda") |
38 |
| -output = model(input) |
39 |
| -output.sum().backward() |
40 |
| -
|
41 |
| -############################################################################# |
42 |
| -# Setting up and running the optimizer benchmark |
43 |
| -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
44 |
| -# In this example, we'll use the Adam optimizer |
45 |
| -# and create a helper function to wrap the step() |
46 |
| -# in torch.compile() |
47 |
| -# |
48 |
| -# .. note:: |
49 |
| -# |
50 |
| -# torch.compile is only supported on cuda devices with compute capability >= 7.0 |
51 |
| - |
52 |
| -# exit cleanly if we are on a device that doesn't support torch.compile |
53 |
| -if torch.cuda.get_device_capability() < (7, 0): |
54 |
| - print("Exiting because torch.compile is not supported on this device.") |
55 |
| - import sys |
56 |
| - sys.exit(0) |
57 |
| - |
58 |
| - |
59 |
| -opt = torch.optim.Adam(model.parameters(), lr=0.01) |
60 |
| - |
61 |
| - |
62 |
| -@torch.compile(fullgraph=False) |
63 |
| -def fn(): |
64 |
| - opt.step() |
65 |
| - |
66 |
| - |
67 |
| -# Let's define a helpful benchmarking function: |
68 |
| -import torch.utils.benchmark as benchmark |
69 |
| - |
70 |
| - |
71 |
| -def benchmark_torch_function_in_microseconds(f, *args, **kwargs): |
72 |
| - t0 = benchmark.Timer( |
73 |
| - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} |
74 |
| - ) |
75 |
| - return t0.blocked_autorange().mean * 1e6 |
76 |
| -
|
77 |
| - |
78 |
| -# Warmup runs to compile the function |
79 |
| -for _ in range(5): |
80 |
| - fn() |
81 |
| - |
82 |
| -eager_runtime = benchmark_torch_function_in_microseconds(opt.step) |
83 |
| -compiled_runtime = benchmark_torch_function_in_microseconds(fn) |
84 |
| - |
85 |
| -assert eager_runtime > compiled_runtime |
86 |
| - |
87 |
| -print(f"eager runtime: {eager_runtime}us") |
88 |
| -print(f"compiled runtime: {compiled_runtime}us") |
89 | 5 |
|
90 |
| -# Sample Results: |
91 |
| -# eager runtime: 747.2437149845064us |
92 |
| -# compiled runtime: 392.07384741178us |
| 6 | +The optimizer is a key algorithm for training any deep learning model. |
| 7 | +Since it is responsible for updating every model parameter, it can often |
| 8 | +become the bottleneck in training performance for large models. In this recipe, |
| 9 | +we will apply ``torch.compile`` to the optimizer to observe the GPU performance |
| 10 | +improvement. |
| 11 | + |
| 12 | +.. note:: |
| 13 | + |
| 14 | + This tutorial requires PyTorch 2.2.0 or later. |
| 15 | + |
| 16 | +Model Setup |
| 17 | +~~~~~~~~~~~~~~~~~~~~~ |
| 18 | +For this example, we'll use a simple sequence of linear layers. |
| 19 | +Since we are only benchmarking the optimizer, the choice of model doesn't matter |
| 20 | +because optimizer performance is a function of the number of parameters. |
| 21 | + |
| 22 | +Depending on what machine you are using, your exact results may vary. |
| 23 | + |
| 24 | +.. code-block:: python |
| 25 | +
|
| 26 | + import torch |
| 27 | + |
| 28 | + model = torch.nn.Sequential( |
| 29 | + *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)] |
| 30 | + ) |
| 31 | + input = torch.rand(1024, device="cuda") |
| 32 | + output = model(input) |
| 33 | + output.sum().backward() |
| 34 | +
|
| 35 | +Setting up and running the optimizer benchmark |
| 36 | +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 37 | +In this example, we'll use the Adam optimizer |
| 38 | +and create a helper function to wrap the step() |
| 39 | +in ``torch.compile()``. |
| 40 | + |
| 41 | +.. note:: |
| 42 | + |
| 43 | + ``torch.compile`` is only supported on cuda devices with compute capability >= 7.0 |
| 44 | + |
| 45 | +.. code-block:: python |
| 46 | +
|
| 47 | + # exit cleanly if we are on a device that doesn't support torch.compile |
| 48 | + if torch.cuda.get_device_capability() < (7, 0): |
| 49 | + print("Exiting because torch.compile is not supported on this device.") |
| 50 | + import sys |
| 51 | + sys.exit(0) |
| 52 | +
|
| 53 | +
|
| 54 | + opt = torch.optim.Adam(model.parameters(), lr=0.01) |
| 55 | +
|
| 56 | +
|
| 57 | + @torch.compile(fullgraph=False) |
| 58 | + def fn(): |
| 59 | + opt.step() |
| 60 | + |
| 61 | + |
| 62 | + # Let's define a helpful benchmarking function: |
| 63 | + import torch.utils.benchmark as benchmark |
| 64 | + |
| 65 | + |
| 66 | + def benchmark_torch_function_in_microseconds(f, *args, **kwargs): |
| 67 | + t0 = benchmark.Timer( |
| 68 | + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} |
| 69 | + ) |
| 70 | + return t0.blocked_autorange().mean * 1e6 |
| 71 | +
|
| 72 | +
|
| 73 | + # Warmup runs to compile the function |
| 74 | + for _ in range(5): |
| 75 | + fn() |
| 76 | + |
| 77 | + eager_runtime = benchmark_torch_function_in_microseconds(opt.step) |
| 78 | + compiled_runtime = benchmark_torch_function_in_microseconds(fn) |
| 79 | + |
| 80 | + assert eager_runtime > compiled_runtime |
| 81 | + |
| 82 | + print(f"eager runtime: {eager_runtime}us") |
| 83 | + print(f"compiled runtime: {compiled_runtime}us") |
| 84 | +
|
| 85 | +Sample Results: |
| 86 | + |
| 87 | +* Eager runtime: 747.2437149845064us |
| 88 | +* Compiled runtime: 392.07384741178us |
0 commit comments