
Pyramid Attention Broadcast rewrite + introduce hooks #9826


Merged · 5 commits · Nov 8, 2024
251 changes: 251 additions & 0 deletions src/diffusers/models/hooks.py
@@ -0,0 +1,251 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Callable, Dict, Tuple, Union

import torch


# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
class ModelHook:
Member: If it's not copy-pasted verbatim, then let's add a note describing how it differs from accelerate.

Member Author: Copied verbatim, except for a single attribute (which I think I will remove soon).

Member: Then why can't we import directly?

Member Author: For now, it would be okay to import, but the implementations may diverge soon: the design I have in mind might require diffusers to have some additional class attributes, and it does not make sense to depend on another library for something core to diffusers IMO (any change they make upstream would affect us, and since a base class like this is very simple, it might be better to have our own implementation). Will use imports if we don't need anything additional at the moment.

Member: Oh, then I am okay to have them copy-pasted with this exact note.

r"""
A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
from PyTorch's existing hooks is that they get passed along the kwargs.
"""

def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is initialized.

Args:
module (`torch.nn.Module`):
The module attached to this hook.
"""
return module

def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
r"""
Hook that is executed just before the forward method of the model.

Args:
module (`torch.nn.Module`):
The module whose forward pass will be executed just after this event.
args (`Tuple[Any]`):
The positional arguments passed to the module.
kwargs (`Dict[str, Any]`):
The keyword arguments passed to the module.

Returns:
`Tuple[Tuple[Any], Dict[str, Any]]`:
A tuple with the treated `args` and `kwargs`.
"""
return args, kwargs

def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
r"""
Hook that is executed just after the forward method of the model.

Args:
module (`torch.nn.Module`):
The module whose forward pass has been executed just before this event.
output (`Any`):
The output of the module.

Returns:
`Any`: The processed `output`.
"""
return output

def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when the hook is detached from a module.

Args:
module (`torch.nn.Module`):
The module detached from this hook.
"""
return module

def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
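r"""
Hook that is executed when the hook's internal state needs to be reset, e.g. at the end of a pipeline invocation
(see `DiffusionPipeline.maybe_free_model_hooks`).

Args:
module (`torch.nn.Module`):
The module attached to this hook.
"""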
return module


class SequentialHook(ModelHook):
Member: Same as above.

Member Author: Copied completely.

r"""A hook that can contain several hooks and iterates through them at each event."""

def __init__(self, *hooks):
self.hooks = hooks

def init_hook(self, module):
for hook in self.hooks:
module = hook.init_hook(module)
return module

def pre_forward(self, module, *args, **kwargs):
for hook in self.hooks:
args, kwargs = hook.pre_forward(module, *args, **kwargs)
return args, kwargs

def post_forward(self, module, output):
for hook in self.hooks:
output = hook.post_forward(module, output)
return output

def detach_hook(self, module):
for hook in self.hooks:
module = hook.detach_hook(module)
return module

def reset_state(self, module):
for hook in self.hooks:
module = hook.reset_state(module)
return module


class PyramidAttentionBroadcastHook(ModelHook):
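r"""
A hook that implements Pyramid Attention Broadcast. While the current timestep lies within `timestep_range`, the
wrapped module's output is recomputed only once every `skip_range` calls; on all other calls the cached output is
reused (broadcast). Outside `timestep_range`, the output is always recomputed.

Args:
skip_range (`int`):
The interval at which the output is recomputed; cached outputs are broadcast in between.
timestep_range (`Tuple[int, int]`):
The open interval of timesteps within which broadcasting is allowed.
timestep_callback (`Callable[[], Union[torch.LongTensor, int]]`):
A callable that returns the current timestep, e.g. by reading the pipeline's `_current_timestep`.
"""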
def __init__(
self,
skip_range: int,
timestep_range: Tuple[int, int],
timestep_callback: Callable[[], Union[torch.LongTensor, int]],
) -> None:
super().__init__()

self.skip_range = skip_range
self.timestep_range = timestep_range
self.timestep_callback = timestep_callback

self.attention_cache = None
self._iteration = 0

def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)

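# Recompute the wrapped module's output when the current timestep is outside the
# broadcasting window, or on every `skip_range`-th call inside it; otherwise
# broadcast the cached output from the last computed step.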
current_timestep = self.timestep_callback()
is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1]
should_compute_attention = self._iteration % self.skip_range == 0

if not is_within_timestep_range or should_compute_attention:
output = module._old_forward(*args, **kwargs)
else:
output = self.attention_cache

self._iteration = self._iteration + 1

return module._diffusers_hook.post_forward(module, output)

def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
self.attention_cache = output
Member: We could consider making this configurable, i.e. if we want to offload the cache to CPU/disk while executing a certain computation on the GPU. This can be done with CUDA streams. This way, we should be able to overlap computation and device transfer, saving time.

https://github.com/huggingface/transformers/blob/f83917ed89292072e57c4ba013d1cc6477538b11/src/transformers/cache_utils.py is a good reference for this.
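As a rough illustration of that idea, a minimal sketch of stream-based cache offloading (not part of this PR; the `OffloadedAttentionCache` class and its method names are assumptions for illustration):

```python
from typing import Optional

import torch


class OffloadedAttentionCache:
    """Illustrative sketch: keep the attention cache in pinned CPU memory and
    overlap the device-to-host copy with ongoing GPU work via a side stream."""

    def __init__(self) -> None:
        self._cpu_cache: Optional[torch.Tensor] = None
        self._stream = torch.cuda.Stream()  # side stream for async copies

    def store(self, output: torch.Tensor) -> None:
        # Wait for the work that produced `output`, then copy it to pinned CPU
        # memory on the side stream so the main stream is not blocked.
        self._stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._stream):
            if self._cpu_cache is None:
                self._cpu_cache = torch.empty(
                    output.shape, dtype=output.dtype, device="cpu", pin_memory=True
                )
            self._cpu_cache.copy_(output, non_blocking=True)
        # Keep the allocator from reusing `output`'s memory before the copy ends.
        output.record_stream(self._stream)

    def load(self, device: torch.device) -> torch.Tensor:
        # Ensure the async copy has finished before reading the cache back.
        torch.cuda.current_stream().wait_stream(self._stream)
        return self._cpu_cache.to(device, non_blocking=True)
```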

return output

def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
self.attention_cache = None
self._iteration = 0
return module


def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
r"""
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
this behavior and restore the original `forward` method, use `remove_hook_from_module`.

<Tip warning={true}>

If the module already contains a hook, this will replace it with the new hook by default. To chain two hooks
together, pass `append=True`; the current and new hooks are then chained into an instance of the `SequentialHook`
class.

</Tip>

Args:
module (`torch.nn.Module`):
The module to attach a hook to.
hook (`ModelHook`):
The hook to attach.
append (`bool`, *optional*, defaults to `False`):
Whether the hook should be chained with an existing one (if module already contains a hook) or not.

Returns:
`torch.nn.Module`:
The same module, with the hook attached (the module is modified in place, so the result can be discarded).
"""
original_hook = hook

if append and getattr(module, "_diffusers_hook", None) is not None:
old_hook = module._diffusers_hook
remove_hook_from_module(module)
hook = SequentialHook(old_hook, hook)

if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"):
# If we already put some hook on this module, we replace it with the new one.
old_forward = module._old_forward
else:
old_forward = module.forward
module._old_forward = old_forward

module = hook.init_hook(module)
module._diffusers_hook = hook

if hasattr(original_hook, "new_forward"):
new_forward = original_hook.new_forward
else:

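# Fallback wrapper used when the hook does not define its own `new_forward`:
# run `pre_forward`, the original forward, then `post_forward`.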
def new_forward(module, *args, **kwargs):
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
output = module._old_forward(*args, **kwargs)
return module._diffusers_hook.post_forward(module, output)

# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
if "GraphModuleImpl" in str(type(module)):
module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
else:
module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)

return module


def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
"""
Removes any hook attached to a module via `add_hook_to_module`.

Args:
module (`torch.nn.Module`):
The module to remove the hook from.
recurse (`bool`, defaults to `False`):
Whether to remove the hooks recursively.

Returns:
`torch.nn.Module`:
The same module, with the hook detached (the module is modified in place, so the result can be discarded).
"""

if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.detach_hook(module)
delattr(module, "_diffusers_hook")

if hasattr(module, "_old_forward"):
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
if "GraphModuleImpl" in str(type(module)):
module.__class__.forward = module._old_forward
else:
module.forward = module._old_forward
delattr(module, "_old_forward")

if recurse:
for child in module.children():
remove_hook_from_module(child, recurse)

return module
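For reference, a minimal sketch of driving the new hook API by hand (the checkpoint, the `attn1` name match, and all parameter values are illustrative assumptions; in this PR the pipelines are expected to wire this up through `PyramidAttentionBroadcastMixin`, whose implementation is not shown in this diff):

```python
import torch

from diffusers import AllegroPipeline
from diffusers.models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16).to("cuda")

# One hook instance per attention module: each hook owns its own output cache.
for name, submodule in pipe.transformer.named_modules():
    if not name.endswith("attn1"):  # illustrative match for the self-attention blocks
        continue
    hook = PyramidAttentionBroadcastHook(
        skip_range=2,  # recompute attention on every 2nd call, broadcast the cache otherwise
        timestep_range=(100, 800),  # only allow broadcasting while 100 < t < 800
        timestep_callback=lambda: pipe._current_timestep,  # set by the pipeline each denoising step
    )
    add_hook_to_module(submodule, hook, append=True)

video = pipe("A panda playing guitar in a bamboo forest", num_inference_steps=100).frames[0]
```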
7 changes: 6 additions & 1 deletion src/diffusers/pipelines/allegro/pipeline_allegro.py
@@ -38,6 +38,7 @@
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
from .pipeline_output import AllegroPipelineOutput


@@ -131,7 +132,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class AllegroPipeline(DiffusionPipeline):
class AllegroPipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin):
r"""
Pipeline for text-to-video generation using Allegro.

@@ -786,6 +787,7 @@ def __call__(
negative_prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False

# 2. Default height and width to transformer
@@ -863,6 +865,7 @@
if self.interrupt:
continue

self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -901,6 +904,8 @@
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

self._current_timestep = None

if not output_type == "latent":
latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents)
@@ -30,6 +30,7 @@
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin
from .pipeline_output import CogVideoXPipelineOutput


@@ -144,7 +145,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin):
r"""
Pipeline for controlled text-to-video generation using CogVideoX Fun.

@@ -650,6 +651,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False

# 2. Default call parameters
@@ -730,6 +732,7 @@
if self.interrupt:
continue

self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -779,6 +782,8 @@
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

self._current_timestep = None

if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
4 changes: 4 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -1082,6 +1082,10 @@ def maybe_free_model_hooks(self):
is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
functions correctly when applying enable_model_cpu_offload.
"""

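# Reset any stateful hooks attached via the hook API (e.g. the Pyramid Attention
# Broadcast cache) so the next pipeline invocation starts from a clean state.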
if hasattr(self, "_diffusers_hook"):
self._diffusers_hook.reset_state()

if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
# `enable_model_cpu_offload` has not been called, so silently do nothing
return