
Commit c52cf42

Pyramid Attention Broadcast rewrite + introduce hooks (#9826)
* rewrite implementation with hooks
* make style
* update
1 parent 18b7d6d commit c52cf42

File tree

5 files changed: +344 -122 lines changed

src/diffusers/models/hooks.py

Lines changed: 251 additions & 0 deletions
@@ -0,0 +1,251 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Callable, Dict, Tuple, Union

import torch

# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
    from existing PyTorch hooks is that these also get passed the kwargs.
    """

    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is initialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Hook that is executed just before the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`):
                The positional arguments passed to the module.
            kwargs (`Dict[str, Any]`):
                The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[str, Any]]`:
                A tuple with the processed `args` and `kwargs`.
        """
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Hook that is executed just after the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass has been executed just before this event.
            output (`Any`):
                The output of the module.

        Returns:
            `Any`: The processed `output`.
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when the hook is detached from a module.

        Args:
            module (`torch.nn.Module`):
                The module detached from this hook.
        """
        return module

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""Hook that is executed to reset any internal state the hook keeps between generations."""
        return module

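To make the callback protocol concrete, here is a minimal sketch of a custom hook — a hypothetical `TimingHook`, not part of this commit — that uses `pre_forward` and `post_forward` to time a module's forward pass:

import time

class TimingHook(ModelHook):
    """Hypothetical example hook that records how long each forward pass takes."""

    def pre_forward(self, module, *args, **kwargs):
        # Record the start time; args and kwargs pass through unchanged.
        self._start = time.perf_counter()
        return args, kwargs

    def post_forward(self, module, output):
        # Measure the elapsed time and return the output untouched.
        self.elapsed = time.perf_counter() - self._start
        return output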
class SequentialHook(ModelHook):
    r"""A hook that can contain several hooks and iterates through them at each event."""

    def __init__(self, *hooks):
        self.hooks = hooks

    def init_hook(self, module):
        for hook in self.hooks:
            module = hook.init_hook(module)
        return module

    def pre_forward(self, module, *args, **kwargs):
        for hook in self.hooks:
            args, kwargs = hook.pre_forward(module, *args, **kwargs)
        return args, kwargs

    def post_forward(self, module, output):
        for hook in self.hooks:
            output = hook.post_forward(module, output)
        return output

    def detach_hook(self, module):
        for hook in self.hooks:
            module = hook.detach_hook(module)
        return module

    def reset_state(self, module):
        for hook in self.hooks:
            module = hook.reset_state(module)
        return module

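Each event fans out to the wrapped hooks in insertion order, so a later hook sees args and outputs already processed by earlier ones. A quick sketch, reusing the hypothetical `TimingHook` from above:

first, second = TimingHook(), TimingHook()
combined = SequentialHook(first, second)

linear = torch.nn.Linear(4, 4)
args, kwargs = combined.pre_forward(linear, torch.randn(1, 4))
output = combined.post_forward(linear, linear(*args, **kwargs))
# Both wrapped hooks have now recorded an `elapsed` time for this call.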
class PyramidAttentionBroadcastHook(ModelHook):
    def __init__(
        self,
        skip_range: int,
        timestep_range: Tuple[int, int],
        timestep_callback: Callable[[], Union[torch.LongTensor, int]],
    ) -> None:
        super().__init__()

        self.skip_range = skip_range
        self.timestep_range = timestep_range
        self.timestep_callback = timestep_callback

        self.attention_cache = None
        self._iteration = 0

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)

        # Recompute attention on every `skip_range`-th call; otherwise replay the cached
        # output, but only while the current timestep lies inside `timestep_range`.
        current_timestep = self.timestep_callback()
        is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1]
        should_compute_attention = self._iteration % self.skip_range == 0

        if not is_within_timestep_range or should_compute_attention:
            output = module._old_forward(*args, **kwargs)
        else:
            output = self.attention_cache

        self._iteration = self._iteration + 1

        return module._diffusers_hook.post_forward(module, output)

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        # Cache the most recent attention output so skipped steps can reuse it.
        self.attention_cache = output
        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.attention_cache = None
        self._iteration = 0
        return module

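This is the core of Pyramid Attention Broadcast: attention outputs change slowly across neighboring denoising steps, so within `timestep_range` the hook recomputes attention only on every `skip_range`-th call and broadcasts the cached result in between. A self-contained sketch of the behavior, using `add_hook_to_module` (defined just below) with a plain `Linear` and a dict as hypothetical stand-ins for an attention block and the pipeline's timestep state:

attention_module = torch.nn.Linear(64, 64)  # stand-in for a real attention block
state = {"t": 500}  # plays the role of pipe._current_timestep below

add_hook_to_module(
    attention_module,
    PyramidAttentionBroadcastHook(
        skip_range=2,                        # recompute when _iteration % 2 == 0
        timestep_range=(100, 800),           # replay the cache only for 100 < t < 800
        timestep_callback=lambda: state["t"],
    ),
)

x = torch.randn(1, 64)
out1 = attention_module(x)  # iteration 0: computes and caches
out2 = attention_module(x)  # iteration 1: returns the cached tensor
assert out1 is out2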
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
    r"""
    Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook. To remove
    this behavior and restore the original `forward` method, use `remove_hook_from_module`.

    <Tip warning={true}>

    By default, if the module already contains a hook, the new hook will replace it. To chain two hooks together,
    pass `append=True`, which combines the current and new hooks into an instance of the `SequentialHook` class.

    </Tip>

    Args:
        module (`torch.nn.Module`):
            The module to attach a hook to.
        hook (`ModelHook`):
            The hook to attach.
        append (`bool`, *optional*, defaults to `False`):
            Whether the hook should be chained with an existing one (if the module already contains a hook) or not.

    Returns:
        `torch.nn.Module`:
            The same module, with the hook attached (the module is modified in place, so the result can be discarded).
    """
    original_hook = hook

    if append and getattr(module, "_diffusers_hook", None) is not None:
        old_hook = module._diffusers_hook
        remove_hook_from_module(module)
        hook = SequentialHook(old_hook, hook)

    if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"):
        # If we already put some hook on this module, we replace it with the new one.
        old_forward = module._old_forward
    else:
        old_forward = module.forward
        module._old_forward = old_forward

    module = hook.init_hook(module)
    module._diffusers_hook = hook

    if hasattr(original_hook, "new_forward"):
        new_forward = original_hook.new_forward
    else:

        def new_forward(module, *args, **kwargs):
            args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
            output = module._old_forward(*args, **kwargs)
            return module._diffusers_hook.post_forward(module, output)

    # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
    # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
    if "GraphModuleImpl" in str(type(module)):
        module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
    else:
        module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)

    return module

def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
    """
    Removes any hook attached to a module via `add_hook_to_module`.

    Args:
        module (`torch.nn.Module`):
            The module to detach the hook from.
        recurse (`bool`, defaults to `False`):
            Whether to remove the hooks recursively.

    Returns:
        `torch.nn.Module`:
            The same module, with the hook detached (the module is modified in place, so the result can be discarded).
    """

    if hasattr(module, "_diffusers_hook"):
        module._diffusers_hook.detach_hook(module)
        delattr(module, "_diffusers_hook")

    if hasattr(module, "_old_forward"):
        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
        if "GraphModuleImpl" in str(type(module)):
            module.__class__.forward = module._old_forward
        else:
            module.forward = module._old_forward
        delattr(module, "_old_forward")

    if recurse:
        for child in module.children():
            remove_hook_from_module(child, recurse)

    return module
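
Taken together, a short sketch of the attach/chain/detach lifecycle (again using the hypothetical `TimingHook` from above):

module = torch.nn.Linear(8, 8)

add_hook_to_module(module, TimingHook())                 # wraps module.forward
add_hook_to_module(module, TimingHook(), append=True)    # chains both into a SequentialHook

module(torch.randn(1, 8))  # pre_forward -> original forward -> post_forward

remove_hook_from_module(module)  # restores the original module.forward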

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 6 additions & 1 deletion
@@ -38,6 +38,7 @@
 )
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
+from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
 from .pipeline_output import AllegroPipelineOutput
 
 
@@ -131,7 +132,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class AllegroPipeline(DiffusionPipeline):
+class AllegroPipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin):
     r"""
     Pipeline for text-to-video generation using Allegro.
 
@@ -786,6 +787,7 @@ def __call__(
             negative_prompt_attention_mask,
         )
         self._guidance_scale = guidance_scale
+        self._current_timestep = None
         self._interrupt = False
 
         # 2. Default height and width to transformer
@@ -863,6 +865,7 @@ def __call__(
                 if self.interrupt:
                     continue
 
+                self._current_timestep = t
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
@@ -901,6 +904,8 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+        self._current_timestep = None
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
             video = self.decode_latents(latents)
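
The new `_current_timestep` attribute is what feeds the hook's `timestep_callback`: the pipeline records `t` at the top of every denoising step and clears it to `None` once the loop finishes. A hedged sketch of how this could be wired up manually — `PyramidAttentionBroadcastMixin`, defined in `pyramid_broadcast_utils.py` and not shown in this excerpt, presumably automates the equivalent, and the `attn1` module naming is an assumption:

from diffusers.models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module

for name, submodule in pipe.transformer.named_modules():
    if name.endswith("attn1"):  # assumed naming for the self-attention blocks
        add_hook_to_module(
            submodule,
            PyramidAttentionBroadcastHook(
                skip_range=2,
                timestep_range=(100, 850),
                timestep_callback=lambda: pipe._current_timestep,
            ),
            append=True,
        )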

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py

Lines changed: 6 additions & 1 deletion
@@ -30,6 +30,7 @@
 from ...utils import logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
+from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
 from .pipeline_output import CogVideoXPipelineOutput
 
 
@@ -144,7 +145,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
+class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
     r"""
     Pipeline for controlled text-to-video generation using CogVideoX Fun.
 
@@ -650,6 +651,7 @@ def __call__(
         )
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False
 
         # 2. Default call parameters
@@ -730,6 +732,7 @@ def __call__(
                 if self.interrupt:
                     continue
 
+                self._current_timestep = t
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
@@ -779,6 +782,8 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+        self._current_timestep = None
+
         if not output_type == "latent":
             video = self.decode_latents(latents)
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 4 additions & 0 deletions
@@ -1082,6 +1082,10 @@ def maybe_free_model_hooks(self):
         is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
         functions correctly when applying enable_model_cpu_offload.
         """
+
+        if hasattr(self, "_diffusers_hook"):
+            self._diffusers_hook.reset_state()
+
         if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
             # `enable_model_cpu_offload` has not be called, so silently do nothing
             return
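
Since `maybe_free_model_hooks` runs at the end of every pipeline's `__call__`, this new `reset_state` call clears the cached attention outputs and iteration counters between generations, so Pyramid Attention Broadcast caches never leak across prompts. A sketch of the effect (pipeline usage and the `frames` attribute as in the video pipelines above):

video_a = pipe(prompt="a sunrise over the ocean").frames
# __call__ ends with maybe_free_model_hooks(), which now also resets the hooks:
# attention_cache -> None and _iteration -> 0 on every hooked module.
video_b = pipe(prompt="a city at night").frames  # starts from a clean cache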
