
Commit 9b3a56e

rohan-varma authored and pytorchmergebot committed
[Optimizer Overlap] Move hooks to own file (#71601)
Summary:
Pull Request resolved: #71601

Moves current prototype optimizer overlap to its own file for a better namespace. No code changes besides a few comment fixes. Note that this code is still prototype and not expected to be used by an end user.

ghstack-source-id: 147458826

Test Plan: CI

Reviewed By: cbalioglu

Differential Revision: D33662678

fbshipit-source-id: 3cc931323230a4b66c02b9e6f744aaf5c48d4d34
(cherry picked from commit 5070595)
1 parent ba08440 commit 9b3a56e

File tree

4 files changed: +68 -61 lines changed


torch/distributed/algorithms/ddp_comm_hooks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
     default_hooks as default,
     powerSGD_hook as powerSGD,
     quantization_hooks as quantization,
+    optimizer_overlap_hooks as optimizer_overlap,
 )
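With this import in place, the prototype helpers are reachable through the package-level alias rather than through default_hooks. A minimal sketch of the resulting namespace, based only on the import added above and the new file introduced below:

from torch.distributed.algorithms.ddp_comm_hooks import optimizer_overlap

# The prototype optimizer-overlap helpers now live under this alias:
#   optimizer_overlap._OptimizerHookState
#   optimizer_overlap._hook_then_optimizer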
torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py

Lines changed: 0 additions & 56 deletions
@@ -105,62 +105,6 @@ def decompress(fut):
         return fut.then(decompress)


-class _OptimizerHookState(object):
-    """
-    Holds state for running optimizer in-line after DDP communication hook.
-    Currently contains only optimizer class which must have a method `step_param`.
-    """
-
-    __slots__ = ["functional_optimizer"]
-
-    def __init__(
-        self, functional_optim_cls, *functional_optim_args, **functional_optim_kwargs
-    ):
-        self.functional_optimizer = functional_optim_cls(
-            [],
-            *functional_optim_args,
-            **functional_optim_kwargs,
-            _allow_empty_param_list=True,
-        )
-        if not hasattr(self.functional_optimizer, "step_param"):
-            raise ValueError(
-                f"Class {functional_optim_cls} must implement method step_param."
-            )
-
-
-# TODO: Add an example to use such a wrapper.
-def _hook_then_optimizer(
-    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
-    optimizer_state: _OptimizerHookState,
-) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
-    r"""
-    Runs optimizer in a functional fashion after DDP communication hook.
-
-    .. warning ::
-        This API is experimental adn subject to change.
-    """
-
-
-    def hook_then_optimizer_wrapper(
-        hook_state, bucket: dist.GradBucket
-    ) -> torch.futures.Future[torch.Tensor]:
-        # Run original hook
-        fut = hook(hook_state, bucket)
-
-        def optimizer_step(fut):
-            gradient_tensors = bucket.gradients()
-            model_params = bucket.parameters()
-            for grad_tensor, model_param in zip(gradient_tensors, model_params):
-                optimizer_state.functional_optimizer.step_param(
-                    model_param,
-                    grad_tensor,
-                )
-            return bucket.buffer()
-        return fut.then(optimizer_step)
-
-    return hook_then_optimizer_wrapper
-
-
 def fp16_compress_wrapper(
     hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:

torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+from typing import Any, Callable
+
+import torch
+import torch.distributed as dist
+
+
+class _OptimizerHookState(object):
+    """
+    Holds state for running optimizer in-line after DDP communication hook.
+    Currently contains only optimizer class which must have a method `step_param`.
+    """
+
+    __slots__ = ["functional_optimizer"]
+
+    def __init__(
+        self, functional_optim_cls, *functional_optim_args, **functional_optim_kwargs
+    ):
+        self.functional_optimizer = functional_optim_cls(
+            [],
+            *functional_optim_args,
+            **functional_optim_kwargs,
+            _allow_empty_param_list=True,
+        )
+        if not hasattr(self.functional_optimizer, "step_param"):
+            raise ValueError(
+                f"Class {functional_optim_cls} must implement method step_param."
+            )
+
+
+# TODO: Add an example to use such a wrapper.
+def _hook_then_optimizer(
+    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
+    optimizer_state: _OptimizerHookState,
+) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
+    r"""
+    Runs optimizer in a functional fashion after DDP communication hook.
+
+    .. warning ::
+        This API is experimental and subject to change.
+    """
+
+    def hook_then_optimizer_wrapper(
+        hook_state, bucket: dist.GradBucket
+    ) -> torch.futures.Future[torch.Tensor]:
+        # Run original hook
+        fut = hook(hook_state, bucket)
+
+        def optimizer_step(fut):
+            gradient_tensors = bucket.gradients()
+            model_params = bucket.parameters()
+            for grad_tensor, model_param in zip(gradient_tensors, model_params):
+                optimizer_state.functional_optimizer.step_param(
+                    model_param,
+                    grad_tensor,
+                )
+            return bucket.buffer()
+
+        return fut.then(optimizer_step)
+
+    return hook_then_optimizer_wrapper
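
The TODO above asks for a usage example. The following is a hedged sketch, not part of this commit, of how the two helpers could be wired into DDP. It assumes a process group has already been initialized, uses a plain nn.Linear model for illustration, and picks _FunctionalSGD as one functional optimizer that is expected to expose step_param; the learning rate and layer sizes are arbitrary.

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import (
    default_hooks as default,
    optimizer_overlap as optimizer_overlap_hooks,
)
from torch.distributed.optim import _FunctionalSGD  # assumed: implements step_param

# Assumes dist.init_process_group(...) has already been called.
model = DDP(nn.Linear(16, 16))

# Wrap the functional optimizer class; _OptimizerHookState constructs it with an
# empty parameter list and requires it to provide step_param(param, grad).
opt_state = optimizer_overlap_hooks._OptimizerHookState(_FunctionalSGD, 0.01)

# Allreduce each bucket first, then step every parameter in that bucket as soon
# as its reduced gradient is available.
model.register_comm_hook(
    None,
    optimizer_overlap_hooks._hook_then_optimizer(default.allreduce_hook, opt_state),
)

This mirrors what the updated parity test below does with allreduce_hook and _OptimizerHookState.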

torch/testing/_internal/distributed/distributed_test.py

Lines changed: 7 additions & 5 deletions
@@ -15,8 +15,6 @@
 import torch
 import torch.cuda
 import torch.distributed as dist
-import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
-import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
 import torch.distributed.algorithms.model_averaging.averagers as averagers
 import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
 import torch.nn as nn
@@ -25,9 +23,13 @@
 from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
 from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
 from torch.cuda.amp import GradScaler, autocast
-from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
+
 from torch.distributed.algorithms.ddp_comm_hooks import (
+    post_localSGD_hook as post_localSGD,
+    powerSGD_hook as powerSGD,
+    default_hooks as default,
     quantization as quantization_hooks,
+    optimizer_overlap as optimizer_overlap_hooks
 )
 from torch.distributed.distributed_c10d import (
     get_world_size,
@@ -3944,14 +3946,14 @@ def _test_ddp_hook_with_optimizer_parity(
         # Register hook that runs allreduce + functional optimizer
         # step.
         allreduce_hook = default.allreduce_hook
-        opt_hook_state = default._OptimizerHookState(
+        opt_hook_state = optimizer_overlap_hooks._OptimizerHookState(
            functional_optim_cls,
            *functional_optim_args,
            **functional_optim_kwargs,
        )
        ddp_model_with_optimizer_hook.register_comm_hook(
            None,
-            default._hook_then_optimizer(allreduce_hook, opt_hook_state),
+            optimizer_overlap_hooks._hook_then_optimizer(allreduce_hook, opt_hook_state),
        )
        # Create DDP model with no hook that does optimizer after
        # backward.
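
For reference, the only contract that _OptimizerHookState enforces on functional_optim_cls is the one the test exercises: the class must accept _allow_empty_param_list=True and expose step_param. A hypothetical minimal stand-in illustrating that contract (not part of this commit; the actual tests use the functional optimizers shipped under torch.distributed.optim):

from typing import List

import torch


class _MinimalFunctionalSGD(object):
    """Hypothetical functional optimizer satisfying the step_param contract."""

    def __init__(
        self,
        params: List[torch.Tensor],
        lr: float = 0.01,
        _allow_empty_param_list: bool = False,
    ):
        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")
        self.lr = lr

    def step_param(self, param: torch.Tensor, grad: torch.Tensor) -> None:
        # Plain SGD update applied in-place to a single parameter.
        with torch.no_grad():
            param.add_(grad, alpha=-self.lr)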
