
Commit 404f254

rohan-varma authored and pytorchmergebot committed
Upstream apply_optim_in_backward from TorchRec (#87397) (#88539)
Summary: Upstreaming this as part of sharing common APIs. This is just a plain move; any changes needed to support DDP / FSDP will come in follow-up diffs.

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D40564646

fbshipit-source-id: 619c434e02196812f8d4db1e40d07290e08b18f9

Pull Request resolved: #88539
Approved by: https://github.com/awgu
1 parent da452bc commit 404f254
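
For orientation, here is a minimal usage sketch of the API this commit upstreams. It is only a sketch, not part of the commit's files: _apply_optimizer_in_backward is a private API, and the model, input shape, and learning rate below are illustrative assumptions.

import torch
import torch.nn as nn
from torch.distributed.optim import _apply_optimizer_in_backward

# Illustrative model; any nn.Module's parameters can be registered.
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))

# Register a per-parameter SGD step (lr chosen arbitrarily) to run during backward().
_apply_optimizer_in_backward(
    torch.optim.SGD,
    model.parameters(),
    optimizer_kwargs={"lr": 0.01},
)

# No explicit optimizer.step(): each parameter is updated as soon as its gradient
# is accumulated, and its .grad is then set to None.
model(torch.randn(4, 10)).sum().backward()

Running the step per parameter inside backward() is what makes this API a candidate for overlapping optimizer work with DDP / FSDP in the follow-up diffs the summary mentions.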

File tree: 3 files changed, +192 -0 lines changed

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
# Owner(s): ["oncall: distributed"]

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from copy import deepcopy

import torch
import torch.nn as nn

from torch.distributed.optim import _apply_optimizer_in_backward

# TODO (rohan-varma): Add FSDP & DDP tests once supported

def _validate_params(params_list, fn):
    ref_params = params_list[0]
    for param_list in params_list[1:]:
        for p1, p2 in zip(ref_params, param_list):
            fn(p1, p2)


class ApplyOverlappedOptimizerTest(unittest.TestCase):

    def _run_training_loop_and_validate(self, inp, models, optimizers):
        for i in range(6):
            for model in models:
                model(inp).sum().backward()
            for opt in optimizers:
                opt.step()

            with self.subTest(i):
                _validate_params(
                    [model.parameters() for model in models],
                    torch.testing.assert_allclose,
                )

            for opt in optimizers:
                opt.zero_grad(set_to_none=True)

    def _test_apply_optimizer_in_backward(self, share_params) -> None:
        weight_optimizer_kwargs = {"lr": 1.0}
        bias_optimizer_kwargs = {"lr": 0.5}
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        if share_params:
            model[0].weight = model[1].weight

        # Use different optimizers for weights & biases.
        weights = [m.weight for m in model]
        biases = [m.bias for m in model]
        optim_weight = torch.optim.SGD(weights, **weight_optimizer_kwargs)
        optim_bias = torch.optim.SGD(biases, **bias_optimizer_kwargs)
        model_with_opt_in_bwd = deepcopy(model)

        # Apply different optimizer in backwards for weights and biases.
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.weight for m in model_with_opt_in_bwd],
            optimizer_kwargs=weight_optimizer_kwargs
        )

        _apply_optimizer_in_backward(
            torch.optim.SGD,
            [m.bias for m in model_with_opt_in_bwd],
            optimizer_kwargs=bias_optimizer_kwargs
        )

        _validate_params(
            [
                model.parameters(),
                model_with_opt_in_bwd.parameters(),
            ],
            torch.testing.assert_allclose,
        )

        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [optim_weight, optim_bias],
        )

    def test_apply_optimizer_in_backward(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=False)

    def test_apply_optimizer_in_backward_shared_params(self) -> None:
        self._test_apply_optimizer_in_backward(share_params=True)

    def test_multiple_optim_for_params(self) -> None:
        model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        opt_0_kwargs = {"lr": 0.03}
        opt_1_kwargs = {"lr": 0.01}
        opt_0 = torch.optim.SGD(model.parameters(), **opt_0_kwargs)
        opt_1 = torch.optim.SGD(model.parameters(), **opt_1_kwargs)
        model_with_opt_in_bwd = deepcopy(model)
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_0_kwargs,
        )
        _apply_optimizer_in_backward(
            torch.optim.SGD,
            model_with_opt_in_bwd.parameters(),
            optimizer_kwargs=opt_1_kwargs,
        )
        self._run_training_loop_and_validate(
            torch.randn(4, 10),
            [model, model_with_opt_in_bwd],
            [opt_0, opt_1],
        )
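
The tests above check that a model whose optimizers run in backward stays numerically in lockstep with a reference model whose optimizers are stepped explicitly. A compressed sketch of that equivalence follows; it is not part of the commit, the shapes and learning rate are arbitrary, and torch.testing.assert_close is used here in place of the test's assert_allclose.

from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed.optim import _apply_optimizer_in_backward

ref = nn.Linear(4, 4)
fused = deepcopy(ref)  # identical weights, but its SGD step runs inside backward()
_apply_optimizer_in_backward(torch.optim.SGD, fused.parameters(), optimizer_kwargs={"lr": 0.1})

opt = torch.optim.SGD(ref.parameters(), lr=0.1)
inp = torch.randn(2, 4)

ref(inp).sum().backward()
opt.step()                    # explicit step for the reference model
fused(inp).sum().backward()   # the fused model is updated during this call

for p_ref, p_fused in zip(ref.parameters(), fused.parameters()):
    torch.testing.assert_close(p_ref, p_fused)  # both models took the same step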

torch/distributed/optim/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from .functional_rprop import _FunctionalRprop
 from .functional_adamax import _FunctionalAdamax
 from .utils import as_functional_optim
+from .apply_optimizer_in_backward import _apply_optimizer_in_backward


 # DistributedOptimizer imports torch.distributed.rpc names, so gate availability
torch/distributed/optim/apply_optimizer_in_backward.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
from typing import Any, Dict, Iterable, Type, List, no_type_check

import torch

__all__: List[str] = []

@no_type_check
def _apply_optimizer_in_backward(
    optimizer_class: Type[torch.optim.Optimizer],
    params: Iterable[torch.nn.Parameter],
    optimizer_kwargs: Dict[str, Any],
) -> None:
    """
    Upon ``backward()``, parameters will fire the corresponding optimizer.

    Note - gradients for these parameters will be set to None after ``backward()``.
    This means that any other (non-applied) optimizer over this parameter will be
    a no-op.

    Args:
        optimizer_class: (Type[torch.optim.Optimizer]): Optimizer to apply to parameter
        params: (Iterator[nn.Parameter]): parameters to apply optimizer state to
        optimizer_kwargs: (Dict[str, Any]): kwargs to pass to optimizer constructor

    Example::
        params_generator = model.parameters()
        param_1 = next(params_generator)
        remainder_params = list(params_generator)

        apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02})
        apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04})

        model(...).sum().backward()  # after backward, parameters will already
                                     # have their registered optimizer applied.

    """

    @no_type_check
    def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
        # view_as creates a node in the autograd graph that gives us access to the
        # parameter's AccumulateGrad autograd function object. We register a
        # hook on this object to fire the optimizer when the gradient for
        # this parameter is ready (has been accumulated into the .grad field).

        # Don't create a new acc_grad if we already have one,
        # i.e. for shared parameters or attaching multiple optimizers to a param.
        if not hasattr(param, '_acc_grad'):
            acc_grad = param.view_as(param).grad_fn.next_functions[0][0]
        else:
            acc_grad = param._acc_grad

        optimizer = optimizer_class([param], **optimizer_kwargs)

        # Keep the grad accumulator around for the lifetime of the Tensor;
        # store it on the param to avoid an uncollectable ref-cycle.
        if not hasattr(param, '_acc_grad'):
            param._acc_grad = acc_grad  # type: ignore[attr-defined]

        if not hasattr(param, '_in_backward_optimizers'):
            param._in_backward_optimizers = []  # type: ignore[attr-defined]
            # TODO: investigate whether we really need these attributes.
            param._optimizer_classes = []  # type: ignore[attr-defined]
            param._optimizer_kwargs = []  # type: ignore[attr-defined]

        param._in_backward_optimizers.append(optimizer)  # type: ignore[attr-defined]
        param._optimizer_classes.append(optimizer_class)  # type: ignore[attr-defined]
        param._optimizer_kwargs.append(optimizer_kwargs)  # type: ignore[attr-defined]

        def optimizer_hook(*_unused) -> None:
            for opt in param._in_backward_optimizers:  # type: ignore[attr-defined]
                opt.step()

            param.grad = None

        param._acc_grad.register_hook(optimizer_hook)  # type: ignore[attr-defined]

    for param in params:
        _apply_optimizer_in_backward_to_param(param)
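
The mechanism everything above hinges on is a hook registered on each parameter's AccumulateGrad node, reached through the view_as trick described in the comments. A standalone sketch of just that mechanism (not part of the commit; the tensor and hook body are illustrative):

import torch

param = torch.nn.Parameter(torch.randn(3))

# view_as(param) adds a node to the autograd graph whose next_functions contain
# the parameter's AccumulateGrad node (its gradient accumulator).
acc_grad = param.view_as(param).grad_fn.next_functions[0][0]

def hook(*_unused) -> None:
    # Fires after the gradient has been accumulated into param.grad; this is
    # where the commit's optimizer_hook calls opt.step() and clears .grad.
    print("grad ready:", param.grad)

acc_grad.register_hook(hook)

(param * 2).sum().backward()  # prints: grad ready: tensor([2., 2., 2.])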
