From 7bce1ea92be33e1f33ca3657e6a794657cea764e Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 31 Oct 2024 23:58:45 +0100 Subject: [PATCH 1/3] rewrite implementation with hooks --- src/diffusers/models/hooks.py | 249 ++++++++++++++++++ .../pipelines/allegro/pipeline_allegro.py | 7 +- .../pipeline_cogvideox_fun_control.py | 7 +- src/diffusers/pipelines/pipeline_utils.py | 4 + .../pipelines/pyramid_broadcast_utils.py | 189 +++++-------- 5 files changed, 332 insertions(+), 124 deletions(-) create mode 100644 src/diffusers/models/hooks.py diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py new file mode 100644 index 000000000000..df4def202621 --- /dev/null +++ b/src/diffusers/models/hooks.py @@ -0,0 +1,249 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, Callable, Dict, Tuple, Union + +import torch + + +# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py +class ModelHook: + r""" + A hook that contains callbacks to be executed just before and after the forward method of a model. The difference + with PyTorch existing hooks is that they get passed along the kwargs. + """ + + _stateful_hook = False + + def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is initialized. + + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: + r""" + Hook that is executed just before the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): + The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): + The keyword arguments passed to the module. + + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: + A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + r""" + Hook that is executed just after the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass been executed just before this event. + output (`Any`): + The output of the module. + + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when the hook is detached from a module. + + Args: + module (`torch.nn.Module`): + The module detached from this hook. + """ + return module + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + return module + + +class SequentialHook(ModelHook): + r"""A hook that can contain several hooks and iterates through them at each event.""" + + def __init__(self, *hooks): + self.hooks = hooks + + def init_hook(self, module): + for hook in self.hooks: + module = hook.init_hook(module) + return module + + def pre_forward(self, module, *args, **kwargs): + for hook in self.hooks: + args, kwargs = hook.pre_forward(module, *args, **kwargs) + return args, kwargs + + def post_forward(self, module, output): + for hook in self.hooks: + output = hook.post_forward(module, output) + return output + + def detach_hook(self, module): + for hook in self.hooks: + module = hook.detach_hook(module) + return module + + def reset_state(self, module): + for hook in self.hooks: + module = hook.reset_state(module) + return module + + +class PyramidAttentionBroadcastHook(ModelHook): + _stateful_hook = True + + def __init__(self, skip_range: int, timestep_range: Tuple[int, int], timestep_callback: Callable[[], Union[torch.LongTensor, int]]) -> None: + super().__init__() + + self.skip_range = skip_range + self.timestep_range = timestep_range + self.timestep_callback = timestep_callback + + self.attention_cache = None + self._iteration = 0 + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + + current_timestep = self.timestep_callback() + is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1] + should_compute_attention = self._iteration % self.skip_range == 0 + + if not is_within_timestep_range or should_compute_attention: + output = module._old_forward(*args, **kwargs) + else: + output = self.attention_cache + + self._iteration = self._iteration + 1 + + return module._diffusers_hook.post_forward(module, output) + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + self.attention_cache = output + return output + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + self.attention_cache = None + self._iteration = 0 + return module + + +def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): + r""" + Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove + this behavior and restore the original `forward` method, use `remove_hook_from_module`. + + + + If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks + together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. + + + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + hook (`ModelHook`): + The hook to attach. + append (`bool`, *optional*, defaults to `False`): + Whether the hook should be chained with an existing one (if module already contains a hook) or not. + + Returns: + `torch.nn.Module`: + The same module, with the hook attached (the module is modified in place, so the result can be discarded). + """ + original_hook = hook + + if append and getattr(module, "_diffusers_hook", None) is not None: + old_hook = module._diffusers_hook + remove_hook_from_module(module) + hook = SequentialHook(old_hook, hook) + + if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"): + # If we already put some hook on this module, we replace it with the new one. + old_forward = module._old_forward + else: + old_forward = module.forward + module._old_forward = old_forward + + module = hook.init_hook(module) + module._diffusers_hook = hook + + if hasattr(original_hook, "new_forward"): + new_forward = original_hook.new_forward + else: + def new_forward(module, *args, **kwargs): + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + output = module._old_forward(*args, **kwargs) + return module._diffusers_hook.post_forward(module, output) + + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + else: + module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + + return module + + +def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module: + """ + Removes any hook attached to a module via `add_hook_to_module`. + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + recurse (`bool`, defaults to `False`): + Whether to remove the hooks recursively + + Returns: + `torch.nn.Module`: + The same module, with the hook detached (the module is modified in place, so the result can be discarded). + """ + + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.detach_hook(module) + delattr(module, "_diffusers_hook") + + if hasattr(module, "_old_forward"): + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = module._old_forward + else: + module.forward = module._old_forward + delattr(module, "_old_forward") + + if recurse: + for child in module.children(): + remove_hook_from_module(child, recurse) + + return module diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 9314960f9618..c93a97920225 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -38,6 +38,7 @@ ) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import AllegroPipelineOutput @@ -131,7 +132,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class AllegroPipeline(DiffusionPipeline): +class AllegroPipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin): r""" Pipeline for text-to-video generation using Allegro. @@ -786,6 +787,7 @@ def __call__( negative_prompt_attention_mask, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._interrupt = False # 2. Default height and width to transformer @@ -863,6 +865,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -901,6 +904,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) video = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 3655075bd519..f2a58c9ad198 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -30,6 +30,7 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -144,7 +145,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): +class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin): r""" Pipeline for controlled text-to-video generation using CogVideoX Fun. @@ -650,6 +651,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Default call parameters @@ -730,6 +732,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -779,6 +782,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2e1858b16148..aa790c830d1a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1082,6 +1082,10 @@ def maybe_free_model_hooks(self): is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions correctly when applying enable_model_cpu_offload. """ + + if hasattr(self, "_diffusers_hook"): + self._diffusers_hook.reset_state() + if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py index 6e917568f33a..522b2385d3b9 100644 --- a/src/diffusers/pipelines/pyramid_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from typing import Optional, Tuple +from typing import Optional, Tuple, List -import torch import torch.nn as nn +from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module, remove_hook_from_module from ..models.attention_processor import Attention, AttentionProcessor from ..utils import logging @@ -25,93 +24,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class PyramidAttentionBroadcastAttentionProcessorWrapper: - r""" - Helper attention processor that wraps logic required for Pyramid Attention Broadcast to function. - - PAB works by caching and re-using attention computations from past inference steps. This is due to the realization - that the attention states do not differ too much numerically between successive inference steps. The difference is - most significant/prominent in the spatial attention blocks, lesser so in the temporal attention blocks, and least - in cross attention blocks. - - Currently, only spatial and cross attention block skipping is supported in Diffusers due to not having any models - tested with temporal attention blocks. Feel free to open a PR adding support for this in case there's a model that - you would like to use PAB with. - - Args: - pipeline ([`~diffusers.DiffusionPipeline`]): - The underlying DiffusionPipeline object that inherits from the PAB Mixin and utilized this attention - processor. - processor ([`~diffusers.models.attention_processor.AttentionProcessor`]): - The underlying attention processor that will be wrapped to cache the intermediate attention computation. - skip_range (`int`): - The attention block to execute after skipping intermediate attention blocks. If set to the value `N`, `N - - 1` attention blocks are skipped and every N'th block is executed. Different models have different - tolerances to how much attention computation can be reused based on the differences between successive - blocks. So, this parameter must be adjusted per model after performing experimentation. Setting this value - to `2` is recommended for different models PAB has been experimented with. - timestep_range (`Tuple[int, int]`): - The timestep range between which PAB will remain activated in attention blocks. While activated, PAB will - re-use attention computations between inference steps. - """ - - def __init__( - self, pipeline, processor: AttentionProcessor, skip_range: int, timestep_range: Tuple[int, int] - ) -> None: - self.pipeline = pipeline - self._original_processor = processor - self._skip_range = skip_range - self._timestep_range = timestep_range - - self._prev_hidden_states = None - self._iteration = 0 - - original_processor_params = set(inspect.signature(self._original_processor.__call__).parameters.keys()) - supported_parameters = { - "attn", - "hidden_states", - "encoder_hidden_states", - "attention_mask", - "temb", - "image_rotary_emb", - } - self._attn_processor_params = supported_parameters.intersection(original_processor_params) - - def __call__( - self, - attn: Attention, - hidden_states: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ): - r"""Method that wraps the underlying call to compute attention and cache states for re-use.""" - - if ( - hasattr(self.pipeline, "_current_timestep") - and self.pipeline._current_timestep is not None - and self._iteration % self._skip_range != 0 - and (self._timestep_range[0] < self.pipeline._current_timestep < self._timestep_range[1]) - ): - # Skip attention computation by re-using past attention states - hidden_states = self._prev_hidden_states - else: - # Perform attention computation - call_kwargs = {} - for param in self._attn_processor_params: - call_kwargs.update({param: locals()[param]}) - call_kwargs.update(kwargs) - hidden_states = self._original_processor(*args, **call_kwargs) - self._prev_hidden_states = hidden_states - - self._iteration = (self._iteration + 1) % self.pipeline.num_timesteps - - return hidden_states - - class PyramidAttentionBroadcastMixin: r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588).""" @@ -120,40 +32,56 @@ def _enable_pyramid_attention_broadcast(self) -> None: for name, module in denoiser.named_modules(): if isinstance(module, Attention): - logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}") + is_spatial_attention = any(x in name for x in self._pab_spatial_attn_layer_identifiers) and self._pab_spatial_attn_skip_range is not None and not module.is_cross_attention + is_temporal_attention = any(x in name for x in self._pab_temporal_attn_layer_identifiers) and self._pab_temporal_attn_skip_range is not None and not module.is_cross_attention + is_cross_attention = any(x in name for x in self._pab_cross_attn_layer_identifiers) and self._pab_cross_attn_skip_range is not None and module.is_cross_attention - skip_range, timestep_range = None, None - if module.is_cross_attention and self._pab_cross_attn_skip_range is not None: - skip_range = self._pab_cross_attn_skip_range - timestep_range = self._pab_cross_attn_timestep_range - if not module.is_cross_attention and self._pab_spatial_attn_skip_range is not None: + if is_spatial_attention: skip_range = self._pab_spatial_attn_skip_range timestep_range = self._pab_spatial_attn_timestep_range - + if is_temporal_attention: + skip_range = self._pab_temporal_attn_skip_range + timestep_range = self._pab_temporal_attn_timestep_range + if is_cross_attention: + skip_range = self._pab_cross_attn_skip_range + timestep_range = self._pab_cross_attn_timestep_range + if skip_range is None: continue - - module.set_processor( - PyramidAttentionBroadcastAttentionProcessorWrapper( - self, module.processor, skip_range, timestep_range - ) + + # logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}") + print(f"Enabling Pyramid Attention Broadcast in layer: {name}") + + add_hook_to_module( + module, + PyramidAttentionBroadcastHook( + skip_range=skip_range, + timestep_range=timestep_range, + timestep_callback=self._pyramid_attention_broadcast_timestep_callback + ), + append=True ) def _disable_pyramid_attention_broadcast(self) -> None: denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet for name, module in denoiser.named_modules(): - if isinstance(module, Attention) and isinstance( - module.processor, PyramidAttentionBroadcastAttentionProcessorWrapper - ): - logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}") - module.processor = module.processor._original_processor + logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}") + remove_hook_from_module(module) + + def _pyramid_attention_broadcast_timestep_callback(self): + return self._current_timestep def enable_pyramid_attention_broadcast( self, spatial_attn_skip_range: Optional[int] = None, + spatial_attn_timestep_range: Tuple[int, int] = (100, 800), + temporal_attn_skip_range: Optional[int] = None, cross_attn_skip_range: Optional[int] = None, - spatial_attn_timestep_range: Optional[Tuple[int, int]] = None, - cross_attn_timestep_range: Optional[Tuple[int, int]] = None, + temporal_attn_timestep_range: Tuple[int, int] = (100, 800), + cross_attn_timestep_range: Tuple[int, int] = (100, 800), + spatial_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"], + temporal_attn_layer_identifiers: List[str] = ["temporal_transformer_blocks"], + cross_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"], ) -> None: r""" Enable pyramid attention broadcast to speedup inference by re-using attention states and skipping computation @@ -166,41 +94,53 @@ def enable_pyramid_attention_broadcast( different tolerances to how much attention computation can be reused based on the differences between successive blocks. So, this parameter must be adjusted per model after performing experimentation. Setting this value to `2` is recommended for different models PAB has been experimented with. + temporal_attn_skip_range (`int`, *optional*): + The attention block to execute after skipping intermediate temporal attention blocks. If set to the + value `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have + different tolerances to how much attention computation can be reused based on the differences between + successive blocks. So, this parameter must be adjusted per model after performing experimentation. + Setting this value to `4` is recommended for different models PAB has been experimented with. cross_attn_skip_range (`int`, *optional*): The attention block to execute after skipping intermediate cross attention blocks. If set to the value `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have different tolerances to how much attention computation can be reused based on the differences between successive blocks. So, this parameter must be adjusted per model after performing experimentation. Setting this value to `6` is recommended for different models PAB has been experimented with. - spatial_attn_timestep_range (`Tuple[int, int]`, *optional*): + spatial_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): The timestep range between which PAB will remain activated in spatial attention blocks. While - activated, PAB will re-use attention computations between inference steps. Setting this to `(100, 850)` - is recommended for different models PAB has been experimented with. - cross_attn_timestep_range (`Tuple[int, int]`, *optional*): + activated, PAB will re-use attention computations between inference steps. + temporal_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): + The timestep range between which PAB will remain activated in temporal attention blocks. While + activated, PAB will re-use attention computations between inference steps. + cross_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): The timestep range between which PAB will remain activated in cross attention blocks. While activated, - PAB will re-use attention computations between inference steps. Setting this to `(100, 800)` is - recommended for different models PAB has been experimented with. + PAB will re-use attention computations between inference steps. """ - - if spatial_attn_timestep_range is None: - spatial_attn_timestep_range = (100, 800) - if cross_attn_skip_range is None: - cross_attn_timestep_range = (100, 800) - + if spatial_attn_timestep_range[0] > spatial_attn_timestep_range[1]: raise ValueError( "Expected `spatial_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." ) + if temporal_attn_timestep_range[0] > temporal_attn_timestep_range[1]: + raise ValueError( + "Expected `temporal_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." + ) if cross_attn_timestep_range[0] > cross_attn_timestep_range[1]: raise ValueError( "Expected `cross_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." ) self._pab_spatial_attn_skip_range = spatial_attn_skip_range + self._pab_temporal_attn_skip_range = temporal_attn_skip_range self._pab_cross_attn_skip_range = cross_attn_skip_range self._pab_spatial_attn_timestep_range = spatial_attn_timestep_range + self._pab_temporal_attn_timestep_range = temporal_attn_timestep_range self._pab_cross_attn_timestep_range = cross_attn_timestep_range - self._pab_enabled = spatial_attn_skip_range or cross_attn_skip_range + self._pab_spatial_attn_layer_identifiers = spatial_attn_layer_identifiers + self._pab_temporal_attn_layer_identifiers = temporal_attn_layer_identifiers + self._pab_cross_attn_layer_identifiers = cross_attn_layer_identifiers + + self._pab_enabled = spatial_attn_skip_range or temporal_attn_skip_range or cross_attn_skip_range self._enable_pyramid_attention_broadcast() @@ -208,9 +148,14 @@ def disable_pyramid_attention_broadcast(self) -> None: r"""Disables the pyramid attention broadcast sampling mechanism.""" self._pab_spatial_attn_skip_range = None + self._pab_temporal_attn_skip_range = None self._pab_cross_attn_skip_range = None self._pab_spatial_attn_timestep_range = None + self._pab_temporal_attn_timestep_range = None self._pab_cross_attn_timestep_range = None + self._pab_spatial_attn_layer_identifiers = None + self._pab_temporal_attn_layer_identifiers = None + self._pab_cross_attn_layer_identifiers = None self._pab_enabled = False @property From 8e661d7dc2479efdd450e1b9abb056f917ccd602 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 31 Oct 2024 23:59:04 +0100 Subject: [PATCH 2/3] make style --- src/diffusers/models/hooks.py | 20 +++++++---- .../pipelines/allegro/pipeline_allegro.py | 2 +- .../pipeline_cogvideox_fun_control.py | 2 +- .../pipelines/pyramid_broadcast_utils.py | 34 +++++++++++++------ 4 files changed, 38 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index df4def202621..fa7e3af2f811 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -118,8 +118,13 @@ def reset_state(self, module): class PyramidAttentionBroadcastHook(ModelHook): _stateful_hook = True - - def __init__(self, skip_range: int, timestep_range: Tuple[int, int], timestep_callback: Callable[[], Union[torch.LongTensor, int]]) -> None: + + def __init__( + self, + skip_range: int, + timestep_range: Tuple[int, int], + timestep_callback: Callable[[], Union[torch.LongTensor, int]], + ) -> None: super().__init__() self.skip_range = skip_range @@ -128,10 +133,10 @@ def __init__(self, skip_range: int, timestep_range: Tuple[int, int], timestep_ca self.attention_cache = None self._iteration = 0 - + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - + current_timestep = self.timestep_callback() is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1] should_compute_attention = self._iteration % self.skip_range == 0 @@ -140,15 +145,15 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: output = module._old_forward(*args, **kwargs) else: output = self.attention_cache - + self._iteration = self._iteration + 1 return module._diffusers_hook.post_forward(module, output) - + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: self.attention_cache = output return output - + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: self.attention_cache = None self._iteration = 0 @@ -199,6 +204,7 @@ def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = if hasattr(original_hook, "new_forward"): new_forward = original_hook.new_forward else: + def new_forward(module, *args, **kwargs): args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) output = module._old_forward(*args, **kwargs) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index c93a97920225..10dd6455092d 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -905,7 +905,7 @@ def __call__( progress_bar.update() self._current_timestep = None - + if not output_type == "latent": latents = latents.to(self.vae.dtype) video = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index f2a58c9ad198..9eeccec50621 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -783,7 +783,7 @@ def __call__( progress_bar.update() self._current_timestep = None - + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py index 522b2385d3b9..7fdb6a7f5b93 100644 --- a/src/diffusers/pipelines/pyramid_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple import torch.nn as nn +from ..models.attention_processor import Attention from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module, remove_hook_from_module -from ..models.attention_processor import Attention, AttentionProcessor from ..utils import logging @@ -32,9 +32,21 @@ def _enable_pyramid_attention_broadcast(self) -> None: for name, module in denoiser.named_modules(): if isinstance(module, Attention): - is_spatial_attention = any(x in name for x in self._pab_spatial_attn_layer_identifiers) and self._pab_spatial_attn_skip_range is not None and not module.is_cross_attention - is_temporal_attention = any(x in name for x in self._pab_temporal_attn_layer_identifiers) and self._pab_temporal_attn_skip_range is not None and not module.is_cross_attention - is_cross_attention = any(x in name for x in self._pab_cross_attn_layer_identifiers) and self._pab_cross_attn_skip_range is not None and module.is_cross_attention + is_spatial_attention = ( + any(x in name for x in self._pab_spatial_attn_layer_identifiers) + and self._pab_spatial_attn_skip_range is not None + and not module.is_cross_attention + ) + is_temporal_attention = ( + any(x in name for x in self._pab_temporal_attn_layer_identifiers) + and self._pab_temporal_attn_skip_range is not None + and not module.is_cross_attention + ) + is_cross_attention = ( + any(x in name for x in self._pab_cross_attn_layer_identifiers) + and self._pab_cross_attn_skip_range is not None + and module.is_cross_attention + ) if is_spatial_attention: skip_range = self._pab_spatial_attn_skip_range @@ -45,10 +57,10 @@ def _enable_pyramid_attention_broadcast(self) -> None: if is_cross_attention: skip_range = self._pab_cross_attn_skip_range timestep_range = self._pab_cross_attn_timestep_range - + if skip_range is None: continue - + # logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}") print(f"Enabling Pyramid Attention Broadcast in layer: {name}") @@ -57,9 +69,9 @@ def _enable_pyramid_attention_broadcast(self) -> None: PyramidAttentionBroadcastHook( skip_range=skip_range, timestep_range=timestep_range, - timestep_callback=self._pyramid_attention_broadcast_timestep_callback + timestep_callback=self._pyramid_attention_broadcast_timestep_callback, ), - append=True + append=True, ) def _disable_pyramid_attention_broadcast(self) -> None: @@ -116,7 +128,7 @@ def enable_pyramid_attention_broadcast( The timestep range between which PAB will remain activated in cross attention blocks. While activated, PAB will re-use attention computations between inference steps. """ - + if spatial_attn_timestep_range[0] > spatial_attn_timestep_range[1]: raise ValueError( "Expected `spatial_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." @@ -139,7 +151,7 @@ def enable_pyramid_attention_broadcast( self._pab_spatial_attn_layer_identifiers = spatial_attn_layer_identifiers self._pab_temporal_attn_layer_identifiers = temporal_attn_layer_identifiers self._pab_cross_attn_layer_identifiers = cross_attn_layer_identifiers - + self._pab_enabled = spatial_attn_skip_range or temporal_attn_skip_range or cross_attn_skip_range self._enable_pyramid_attention_broadcast() From 77f17a545953294e69053391dd1f0e1b8d87e735 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 5 Nov 2024 15:18:08 +0100 Subject: [PATCH 3/3] update --- src/diffusers/models/hooks.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index fa7e3af2f811..2b4351d4a94e 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -25,8 +25,6 @@ class ModelHook: with PyTorch existing hooks is that they get passed along the kwargs. """ - _stateful_hook = False - def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. @@ -117,8 +115,6 @@ def reset_state(self, module): class PyramidAttentionBroadcastHook(ModelHook): - _stateful_hook = True - def __init__( self, skip_range: int,