Foreach_map tutorial #3318
Merged
Commits (8)
ae45035  Foreach_map tutorial (mlazos)
a96d0c4  Update recipes_source/foreach_map.py (mlazos)
da21908  Udpates to tutorial (mlazos)
ada986d  More updates (mlazos)
13bc5e3  Merge branch 'main' into mlazos/foreach_map_tutorial (mlazos)
340eedd  Merge branch 'main' into mlazos/foreach_map_tutorial (AlannaBurke)
7b6dc86  Merge branch 'main' into mlazos/foreach_map_tutorial (svekars)
bb6cb29  Merge branch 'main' into mlazos/foreach_map_tutorial (svekars)
recipes_source/foreach_map.py
@@ -0,0 +1,174 @@
""" | ||
(beta) Explicit horizontal fusion with foreach_map and torch.compile | ||
============================================================ | ||
|
||
**Author:** `Michael Lazos <https://github.com/mlazos>`_ | ||
""" | ||

#########################################################
# Horizontal fusion is a key optimization in ML compilers. In eager mode, it
# is typically expressed using the ``torch._foreach*`` ops, which parallelize
# an operation across a list of tensors. However, supporting all possible
# permutations of arguments (for example, mixtures of scalars and lists) is
# quite difficult. ``foreach_map`` allows converting any pointwise op in
# ``torch`` into a horizontally fused foreach variant. In this tutorial, we
# will demonstrate how to implement the Adam optimizer with ``foreach_map``
# and generate a fully fused kernel.
#
# .. note::
#
#    This tutorial requires PyTorch 2.6.0 or later.
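######################################################################
# As a quick illustration (a minimal sketch with made-up tensors, assuming a
# CUDA device with compute capability 7.0 or higher), ``foreach_map`` takes a
# pointwise function and maps it across lists of tensors, much like the
# ``torch._foreach*`` ops:
#
# .. code-block:: python
#
#     import torch
#     from torch._higher_order_ops.foreach_map import foreach_map
#
#     xs = [torch.ones(3, device="cuda") for _ in range(4)]
#     ys = [torch.full((3,), 2.0, device="cuda") for _ in range(4)]
#
#     # torch.mul is mapped across both lists; under torch.compile the four
#     # multiplies are fused into a single horizontally fused kernel
#     fused_mul = torch.compile(lambda a, b: foreach_map(torch.mul, a, b))
#     out = fused_mul(xs, ys)  # list of 4 tensors, each filled with 2.0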

#####################################################################
# Model Setup
# ~~~~~~~~~~~~~~~~~~~~~
# For this example, we'll use a simple sequence of linear layers.
# We instantiate an independent copy to compare the two optimizer implementations.
#
import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys

    sys.exit(0)

# Create a simple model and an independent copy of it
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
model_copy = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
# Give the copy the same weights so the two optimizers can be compared directly
model_copy.load_state_dict(model.state_dict())

input = torch.rand(1024, device="cuda")

# run forward pass
output = model(input)
output_copy = model_copy(input)

# run backward to populate the grads for our optimizer below
output.sum().backward()
output_copy.sum().backward()

#####################################################################
# Helper functions for foreach_map implementation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll begin our implementation of the Adam optimizer.
#
from torch._higher_order_ops.foreach_map import foreach_map


# Helper function to extract optimizer states from a torch.optim.Adam instance
def get_inputs(optim):
    steps = []
    params = []
    grads = []
    exp_avgs = []
    exp_avg_sqs = []
    for group in optim.param_groups:
        for p in group["params"]:
            params.append(p)
            grads.append(p.grad)
            state = optim.state[p]
            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            steps.append(state["step"])

    return steps, params, exp_avgs, exp_avg_sqs


# Functions to update the different optimizer states
def update_exp_avg_sq(exp_avg_sq, grad, beta2):
    # exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad^2
    return exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2)


def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
    # Fold the bias corrections and learning rate into the denominator so that
    # the update becomes a single add of exp_avg / denom
    bias_correction1 = 1 - torch.pow(beta1, step)
    bias_correction2 = (1 - torch.pow(beta2, step)).sqrt()
    step_size = (lr / bias_correction1).neg()
    denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)
    return torch.add(param, torch.div(exp_avg, denom))
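######################################################################
# As a quick sanity check (a minimal sketch with throwaway tensors, not part
# of the optimizer itself), ``update_param`` should agree with the textbook
# Adam update ``param - lr * m_hat / (sqrt(v_hat) + eps)``, where ``m_hat``
# and ``v_hat`` are the bias-corrected first and second moments:
#
# .. code-block:: python
#
#     p, m, v = torch.rand(8), torch.rand(8), torch.rand(8)
#     step, beta1, beta2, lr, eps = 3.0, 0.9, 0.999, 1e-3, 1e-8
#
#     m_hat = m / (1 - beta1**step)
#     v_hat = v / (1 - beta2**step)
#     reference = p - lr * m_hat / (v_hat.sqrt() + eps)
#
#     fused = update_param(p, torch.tensor(step), m, v, beta1, beta2, lr, eps)
#     assert torch.allclose(reference, fused)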


# Our full Adam implementation
def foreach_map_adam(
    steps,
    params,
    exp_avgs,
    exp_avg_sqs,
    weight_decay=0,
    beta1=0.9,
    beta2=0.999,
    lr=1e-3,
    eps=1e-8,
):
    with torch.no_grad():
        grads = [param.grad for param in params]
        # Update the step counts
        updated_steps = foreach_map(lambda x: x + 1, steps)
        torch._foreach_copy_(steps, updated_steps)

        # Apply (L2) weight decay by adding weight_decay * param to the gradient
        if weight_decay != 0:
            grads = foreach_map(torch.add, grads, params, alpha=weight_decay)

        # Higher-order operators (HOPs) cannot have multiple outputs at the moment,
        # so we need to call foreach_map once for each output
        exp_avgs_updated = foreach_map(torch.lerp, exp_avgs, grads, 1 - beta1)
        exp_avgs_sq_updated = foreach_map(update_exp_avg_sq, exp_avg_sqs, grads, beta2)
        params_updated = foreach_map(
            update_param,
            params,
            steps,
            exp_avgs_updated,
            exp_avgs_sq_updated,
            beta1,
            beta2,
            lr,
            eps,
        )
        # Higher-order operators (HOPs) don't support input mutation today,
        # so we manually copy the updated states back in-place
        torch._foreach_copy_(exp_avgs, exp_avgs_updated)
        torch._foreach_copy_(exp_avg_sqs, exp_avgs_sq_updated)
        torch._foreach_copy_(params, params_updated)
    return

#####################################################################
# Setting up and running the compiled kernel
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll run our Adam optimizer
# and compare the results against the eager ``torch.optim.Adam`` baseline.
#
# .. note::
#
#    ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.
opt_eager = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
opt_eager_copy = torch.optim.Adam(model_copy.parameters(), lr=torch.tensor(0.01))

# warm up the optimizer state dicts so that exp_avg, exp_avg_sq, and step exist
opt_eager.step()
opt_eager_copy.step()

inputs = get_inputs(opt_eager_copy)
compiled_adam = torch.compile(foreach_map_adam)

# optionally view the output code
torch._logging.set_logs(output_code=True)

# Warmup runs to compile the function; use the same learning rate as the eager optimizer
for _ in range(5):
    opt_eager.step()
    compiled_adam(*inputs, lr=0.01)

# Check how closely the compiled foreach_map implementation tracks eager Adam
for eager_p, compile_p in zip(
    opt_eager.param_groups[0]["params"], opt_eager_copy.param_groups[0]["params"]
):
    print(torch.allclose(eager_p, compile_p))
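######################################################################
# As an optional extra check (a minimal sketch that reuses the objects defined
# above), the internal optimizer states can be compared the same way as the
# parameters:
#
# .. code-block:: python
#
#     _, _, compiled_exp_avgs, compiled_exp_avg_sqs = get_inputs(opt_eager_copy)
#     eager_states = [opt_eager.state[p] for p in opt_eager.param_groups[0]["params"]]
#
#     for state, exp_avg, exp_avg_sq in zip(
#         eager_states, compiled_exp_avgs, compiled_exp_avg_sqs
#     ):
#         print(torch.allclose(state["exp_avg"], exp_avg))
#         print(torch.allclose(state["exp_avg_sq"], exp_avg_sq))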

######################################################################
# Conclusion
# ~~~~~~~~~~
# In this tutorial, we implemented a custom fully fused Adam optimizer using ``foreach_map``.
#
# See also:
#
# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an introduction to the compiled optimizer.
# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.