|
| 1 | +(beta) Compiling the optimizer with torch.compile |
| 2 | +========================================================================================== |
| 3 | + |
| 4 | +**Author:** `Michael Lazos <https://github.com/mlazos>`_ |
| 5 | + |
| 6 | +The optimizer is a key algorithm for training any deep learning model. |
| 7 | +Since it is responsible for updating every model parameter, it can often |
| 8 | +become the bottleneck in training performance for large models. In this recipe, |
| 9 | +we will apply ``torch.compile`` to the optimizer to observe the GPU performance |
| 10 | +improvement. |
| 11 | + |
| 12 | +.. note:: |
| 13 | + |
| 14 | + This tutorial requires PyTorch 2.2.0 or later. |
| 15 | + |
| 16 | +Model Setup |
| 17 | +~~~~~~~~~~~~~~~~~~~~~ |
| 18 | +For this example, we'll use a simple sequence of linear layers. |
| 19 | +Since we are only benchmarking the optimizer, the choice of model doesn't matter |
| 20 | +because optimizer performance is a function of the number of parameters. |
| 21 | + |
| 22 | +Depending on what machine you are using, your exact results may vary. |
| 23 | + |
| 24 | +.. code-block:: python |
| 25 | +
|
| 26 | + import torch |
| 27 | + |
| 28 | + model = torch.nn.Sequential( |
| 29 | + *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)] |
| 30 | + ) |
| 31 | + input = torch.rand(1024, device="cuda") |
| 32 | + output = model(input) |
| 33 | + output.sum().backward() |
| 34 | +
|
| 35 | +Setting up and running the optimizer benchmark |
| 36 | +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 37 | +In this example, we'll use the Adam optimizer |
| 38 | +and create a helper function to wrap the step() |
| 39 | +in ``torch.compile()``. |
| 40 | + |
| 41 | +.. note:: |
| 42 | + |
| 43 | + ``torch.compile`` is only supported on cuda devices with compute capability >= 7.0 |
| 44 | + |
| 45 | +.. code-block:: python |
| 46 | +
|
| 47 | + # exit cleanly if we are on a device that doesn't support torch.compile |
| 48 | + if torch.cuda.get_device_capability() < (7, 0): |
| 49 | + print("Exiting because torch.compile is not supported on this device.") |
| 50 | + import sys |
| 51 | + sys.exit(0) |
| 52 | +
|
| 53 | +
|
| 54 | + opt = torch.optim.Adam(model.parameters(), lr=0.01) |
| 55 | +
|
| 56 | +
|
| 57 | + @torch.compile(fullgraph=False) |
| 58 | + def fn(): |
| 59 | + opt.step() |
| 60 | + |
| 61 | + |
| 62 | + # Let's define a helpful benchmarking function: |
| 63 | + import torch.utils.benchmark as benchmark |
| 64 | + |
| 65 | + |
| 66 | + def benchmark_torch_function_in_microseconds(f, *args, **kwargs): |
| 67 | + t0 = benchmark.Timer( |
| 68 | + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} |
| 69 | + ) |
| 70 | + return t0.blocked_autorange().mean * 1e6 |
| 71 | +
|
| 72 | +
|
| 73 | + # Warmup runs to compile the function |
| 74 | + for _ in range(5): |
| 75 | + fn() |
| 76 | + |
| 77 | + eager_runtime = benchmark_torch_function_in_microseconds(opt.step) |
| 78 | + compiled_runtime = benchmark_torch_function_in_microseconds(fn) |
| 79 | + |
| 80 | + assert eager_runtime > compiled_runtime |
| 81 | + |
| 82 | + print(f"eager runtime: {eager_runtime}us") |
| 83 | + print(f"compiled runtime: {compiled_runtime}us") |
| 84 | +
|
| 85 | +Sample Results: |
| 86 | + |
| 87 | +* Eager runtime: 747.2437149845064us |
| 88 | +* Compiled runtime: 392.07384741178us |
0 commit comments