7 | 7 |
8 | 8 | #########################################################
9 | 9 | # Horizontal fusion is a key optimization in ML compilers. In eager,
10 |    | -# this is typically expressed using the torch._foreach* ops which paralellizes
11 |    | -# operations across a list of tensors. However, supporting all possible permuatations
   | 10 | +# this is typically expressed using the torch._foreach* ops which parallelizes
   | 11 | +# operations across a list of tensors. However, supporting all possible permutations
12 | 12 | # of arguments is quite difficult (e.g. mixtures of scalars and lists). Foreach_map
13 |    | -# allows conversion of any pointwise op in torch to a horiztonally fused foreach
14 |    | -# variant. In this tutorial, we will demonstrate how implement the Adam optimizer
   | 13 | +# allows conversion of any pointwise op in ``torch`` to a horizontally fused foreach
   | 14 | +# variant. In this tutorial, we will demonstrate how to implement the Adam optimizer
15 | 15 | # with ``foreach_map`` to generate a fully fused kernel.
16 | 16 | #
17 | 17 | #
18 | 18 | # .. note::
19 | 19 | #
20 |    | -#    This tutorial requires PyTorch 2.6.0 or later.
   | 20 | +#    This tutorial requires PyTorch 2.7.0 or later.
21 | 21 |
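Editor's aside (not part of the diff): the eager ``torch._foreach*`` ops mentioned above apply one operation across an entire list of tensors in a single call rather than looping tensor by tensor. A minimal sketch of the idea:

    import torch

    xs = [torch.randn(1024) for _ in range(10)]
    ys = [torch.randn(1024) for _ in range(10)]

    # Unfused: one elementwise add per tensor in the list
    looped = [x + y for x, y in zip(xs, ys)]

    # Horizontally fused eager op: a single _foreach call covers the whole list
    fused = torch._foreach_add(xs, ys)

    assert all(torch.equal(a, b) for a, b in zip(looped, fused))

This is the eager baseline that ``foreach_map`` generalizes to arbitrary pointwise ops.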
22 | 22 | #####################################################################
23 | 23 | # Model Setup
24 | 24 | # ~~~~~~~~~~~~~~~~~~~~~
25 | 25 | # For this example, we'll use a simple sequence of linear layers.
26 | 26 | # We instantiate an independent copy to compare the two optimizer implementations.
27 | 27 | #
   | 28 | +import torch
28 | 29 |
29 | 30 | # exit cleanly if we are on a device that doesn't support ``torch.compile``
30 | 31 | if torch.cuda.get_device_capability() < (7, 0):
31 | 32 |     print("Exiting because torch.compile is not supported on this device.")
32 | 33 |     import sys
33 | 34 |     sys.exit(0)
34 | 35 |
35 |    | -import torch
36 |    | -
37 | 36 | # Create simple model
38 | 37 | model = torch.nn.Sequential(
39 | 38 |     *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
@@ -55,7 +54,7 @@
55 | 54 | # Helper functions for foreach_map implementation
56 | 55 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
57 | 56 | #
58 |    | -# In this section, we'll begin out implementation of the Adam optimizer.
   | 57 | +# In this section, we'll begin our implementation of the Adam optimizer.
59 | 58 | #
60 | 59 | from torch._higher_order_ops.foreach_map import foreach_map
61 | 60 |
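Editor's aside (not part of the diff): with the import above, ``foreach_map`` lifts an ordinary per-tensor pointwise function across lists of tensors. A minimal sketch, assuming the call is wrapped in ``torch.compile`` so the higher-order op can be traced and fused, and that scalar arguments are broadcast alongside the lists as in the tutorial's Adam helpers:

    import torch
    from torch._higher_order_ops.foreach_map import foreach_map

    def scaled_add(x, y, alpha):
        # An ordinary per-tensor pointwise op
        return x + alpha * y

    xs = [torch.randn(1024, device="cuda") for _ in range(10)]
    ys = [torch.randn(1024, device="cuda") for _ in range(10)]

    # foreach_map broadcasts scaled_add across the lists (the scalar passes
    # through unchanged), and torch.compile fuses the per-tensor ops horizontally.
    @torch.compile
    def fused_scaled_add(xs, ys):
        return foreach_map(scaled_add, xs, ys, 0.5)

    out = fused_scaled_add(xs, ys)

This mixture of lists and a scalar is exactly the kind of argument permutation the introduction notes is hard to support with dedicated ``torch._foreach*`` ops.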
@@ -89,7 +88,7 @@ def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
89 | 88 |     denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)
90 | 89 |     return torch.add(param, torch.div(exp_avg, denom))
91 | 90 |
92 |    | -# Our full adam implementation
   | 91 | +# Our full Adam implementation
93 | 92 | def foreach_map_adam(
94 | 93 |     steps,
95 | 94 |     params,
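Editor's aside (not part of the diff; the body of ``foreach_map_adam`` is elided from this excerpt): the pattern the tutorial builds on is to express each piece of the Adam update as a per-tensor helper, like ``update_param`` above, and let ``foreach_map`` apply it across every parameter and optimizer state at once. A self-contained sketch of that pattern for the first-moment update only, with illustrative names and hyperparameters that are not taken from the tutorial:

    import torch
    from torch._higher_order_ops.foreach_map import foreach_map

    beta1 = 0.9

    def update_exp_avg(exp_avg, grad):
        # Per-tensor exponential moving average of the gradient (Adam's first moment)
        return exp_avg * beta1 + grad * (1 - beta1)

    params = [torch.randn(1024, device="cuda") for _ in range(10)]
    grads = [torch.randn_like(p) for p in params]
    exp_avgs = [torch.zeros_like(p) for p in params]

    # Lifting the helper over the whole list of states gives torch.compile a single
    # graph it can lower to one horizontally fused kernel.
    @torch.compile
    def first_moment_step(exp_avgs, grads):
        return foreach_map(update_exp_avg, exp_avgs, grads)

    exp_avgs = first_moment_step(exp_avgs, grads)

The full ``foreach_map_adam`` extends the same idea to the second moment and to the parameter update itself.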
@@ -166,7 +165,11 @@ def foreach_map_adam(
166 | 165 | ######################################################################
167 | 166 | # Conclusion
168 | 167 | # ~~~~~~~~~~
169 |     | -# In this tutorial, we implemented a custom fully fused Adam optimizer using foreach_map.
    | 168 | +# In this tutorial, we successfully implemented a custom fully fused Adam optimizer using foreach_map.
    | 169 | +# By leveraging the power of foreach_map and torch.compile, we were able to create an optimized version of the Adam
    | 170 | +# optimizer that can be used in various machine learning applications. This tutorial provides a comprehensive guide
    | 171 | +# on how to use foreach_map and torch.compile to optimize machine learning models, and serves as a
    | 172 | +# valuable resource for developers looking to improve the performance of their models with horizontal fusion.
170 | 173 | #
171 | 174 | # See also:
172 | 175 | #