
Commit da21908

Updates to tutorial
1 parent a96d0c4 commit da21908

File tree

1 file changed (+13, -10 lines)


recipes_source/foreach_map.py

Lines changed: 13 additions & 10 deletions
@@ -7,33 +7,32 @@
 
 #########################################################
 # Horizontal fusion is a key optimization in ML compilers. In eager,
-# this is typically expressed using the torch._foreach* ops which paralellizes
-# operations across a list of tensors. However, supporting all possible permuatations
+# this is typically expressed using the torch._foreach* ops which parallelizes
+# operations across a list of tensors. However, supporting all possible permutations
 # of arguments is quite difficult (e.g. mixtures of scalars and lists). Foreach_map
-# allows conversion of any pointwise op in torch to a horiztonally fused foreach
-# variant. In this tutorial, we will demonstrate how implement the Adam optimizer
+# allows conversion of any pointwise op in ``torch`` to a horiztonally fused foreach
+# variant. In this tutorial, we will demonstrate how to implement the Adam optimizer
 # with ``foreach_map`` to generate a fully fused kernel.
 #
 #
 # .. note::
 #
-#     This tutorial requires PyTorch 2.6.0 or later.
+#     This tutorial requires PyTorch 2.7.0 or later.
 
 #####################################################################
 # Model Setup
 # ~~~~~~~~~~~~~~~~~~~~~
 # For this example, we'll use a simple sequence of linear layers.
 # We instantiate an independent copy to compare the two optimizer implementations.
 #
+import torch
 
 # exit cleanly if we are on a device that doesn't support ``torch.compile``
 if torch.cuda.get_device_capability() < (7, 0):
     print("Exiting because torch.compile is not supported on this device.")
     import sys
     sys.exit(0)
 
-import torch
-
 # Create simple model
 model = torch.nn.Sequential(
     *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
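Editor's note: the intro text above describes horizontal fusion in eager mode. A minimal sketch (not part of this commit; tensor names are illustrative) shows the difference between a per-tensor Python loop and a single torch._foreach* op applied to the whole list:

    import torch

    params = [torch.randn(1024, 1024, device="cuda") for _ in range(10)]
    grads = [torch.randn_like(p) for p in params]

    # Unfused: one elementwise kernel launch per tensor in the list.
    for p, g in zip(params, grads):
        p.add_(g, alpha=-0.01)

    # Horizontally fused eager variant: a single foreach op updates every tensor.
    torch._foreach_add_(params, grads, alpha=-0.01)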
@@ -55,7 +54,7 @@
 # Helper functions for foreach_map implementation
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
-# In this section, we'll begin out implementation of the Adam optimizer.
+# In this section, we'll begin our implementation of the Adam optimizer.
 #
 from torch._higher_order_ops.foreach_map import foreach_map
 
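Editor's note: to make the imported foreach_map concrete, here is a minimal usage sketch. The helper, tensor names, and the call pattern (a pointwise function followed by lists and scalars) are assumptions based on the recipe's description of mixing scalars and lists, not taken from the commit; consult the full recipe for the exact signature.

    import torch
    from torch._higher_order_ops.foreach_map import foreach_map

    # A pointwise helper written against single tensors and a scalar.
    def scale_and_add(x, y, alpha):
        return x + alpha * y

    xs = [torch.randn(1024, device="cuda") for _ in range(4)]
    ys = [torch.randn(1024, device="cuda") for _ in range(4)]

    # foreach_map lifts the helper across the lists; under torch.compile this
    # is expected to lower to a single horizontally fused kernel.
    @torch.compile
    def fused_scale_and_add(xs, ys):
        return foreach_map(scale_and_add, xs, ys, 0.1)

    out = fused_scale_and_add(xs, ys)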

@@ -89,7 +88,7 @@ def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
     denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)
     return torch.add(param, torch.div(exp_avg, denom))
 
-# Our full adam implementation
+# Our full Adam implementation
 def foreach_map_adam(
     steps,
     params,
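Editor's note: this hunk shows only the first parameters of foreach_map_adam. As a hedged usage sketch, the fully fused kernel mentioned in the intro would typically be obtained by compiling the optimizer step with torch.compile; every argument name after steps and params below is a placeholder, since the rest of the signature is not visible in this diff.

    # Placeholder state lists: grads, exp_avgs, exp_avg_sqs are illustrative names.
    compiled_adam = torch.compile(foreach_map_adam)
    with torch.no_grad():
        compiled_adam(steps, params, grads, exp_avgs, exp_avg_sqs)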
@@ -166,7 +165,11 @@ def foreach_map_adam(
 ######################################################################
 # Conclusion
 # ~~~~~~~~~~
-# In this tutorial, we implemented a custom fully fused Adam optimizer using foreach_map.
+# In this tutorial, we successfully implemented a custom fully-fused Adam optimizer using foreach_map.
+# By leveraging the power of foreach_map and torch.compile, we were able to create an optimized version of the Adam
+# optimizer that can be used in various machine learning applications. This tutorial provides a comprehensive guide
+# on how to use foreach_map and torch.compile to optimize machine learning models, and serves as a
+# valuable resource for developers looking to improve the performance of their models with horizontal fusion.
 #
 # See also:
 #

0 commit comments
