diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index fc3022cf7b35..752219b4abd1 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -598,6 +598,8 @@
       title: Attention Processor
   - local: api/activations
     title: Custom activation functions
+  - local: api/cache
+    title: Caching methods
   - local: api/normalization
     title: Custom normalization layers
   - local: api/utilities
diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md
new file mode 100644
index 000000000000..403dbf88b431
--- /dev/null
+++ b/docs/source/en/api/cache.md
@@ -0,0 +1,49 @@
+
+
+# Caching methods
+
+## Pyramid Attention Broadcast
+
+[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
+
+Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. This works because attention states change very little between successive inference steps: the difference is most pronounced in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Cross attention computations can therefore be skipped most often, followed by the temporal and then the spatial attention blocks. Combined with other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
+
+Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
+# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
+# broadcast is active, leading to slower inference speeds. However, large intervals can lead to
+# poorer quality of generated videos.
+config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(100, 800), + current_timestep_callback=lambda: pipe.current_timestep, +) +pipe.transformer.enable_cache(config) +``` + +### CacheMixin + +[[autodoc]] CacheMixin + +### PyramidAttentionBroadcastConfig + +[[autodoc]] PyramidAttentionBroadcastConfig + +[[autodoc]] apply_pyramid_attention_broadcast diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b1801fbb2b4b..c36226225ad4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -28,6 +28,7 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], @@ -75,6 +76,13 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["hooks"].extend( + [ + "HookRegistry", + "PyramidAttentionBroadcastConfig", + "apply_pyramid_attention_broadcast", + ] + ) _import_structure["models"].extend( [ "AllegroTransformer3DModel", @@ -90,6 +98,7 @@ "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", + "CacheMixin", "CogVideoXTransformer3DModel", "CogView3PlusTransformer2DModel", "ConsisIDTransformer3DModel", @@ -588,6 +597,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, @@ -602,6 +612,7 @@ AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, + CacheMixin, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, ConsisIDTransformer3DModel, diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 91b2760acad0..e745b1320e84 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -2,4 +2,6 @@ if is_torch_available(): + from .hooks import HookRegistry, ModelHook from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook + from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index bef4c65c41e1..3b2e4ed91c2f 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -30,6 +30,9 @@ class ModelHook: _is_stateful = False + def __init__(self): + self.fn_ref: "HookFunctionReference" = None + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. @@ -48,8 +51,6 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: module (`torch.nn.Module`): The module attached to this hook. """ - module.forward = module._old_forward - del module._old_forward return module def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: @@ -99,6 +100,29 @@ def reset_state(self, module: torch.nn.Module): return module +class HookFunctionReference: + def __init__(self) -> None: + """A container class that maintains mutable references to forward pass functions in a hook chain. + + Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the + entire forward pass structure. + + Attributes: + pre_forward: A callable that processes inputs before the main forward pass. 
+ post_forward: A callable that processes outputs after the main forward pass. + forward: The current forward function in the hook chain. + original_forward: The original forward function, stored when a hook provides a custom new_forward. + + The class enables hook removal by allowing updates to the forward chain through reference modification rather + than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to + be updated, preserving the execution order of the remaining hooks. + """ + self.pre_forward = None + self.post_forward = None + self.forward = None + self.original_forward = None + + class HookRegistry: def __init__(self, module_ref: torch.nn.Module) -> None: super().__init__() @@ -107,51 +131,71 @@ def __init__(self, module_ref: torch.nn.Module) -> None: self._module_ref = module_ref self._hook_order = [] + self._fn_refs = [] def register_hook(self, hook: ModelHook, name: str) -> None: if name in self.hooks.keys(): - logger.warning(f"Hook with name {name} already exists, replacing it.") - - if hasattr(self._module_ref, "_old_forward"): - old_forward = self._module_ref._old_forward - else: - old_forward = self._module_ref.forward - self._module_ref._old_forward = self._module_ref.forward + raise ValueError( + f"Hook with name {name} already exists in the registry. Please use a different name or " + f"first remove the existing hook and then add a new one." + ) self._module_ref = hook.initialize_hook(self._module_ref) - if hasattr(hook, "new_forward"): - rewritten_forward = hook.new_forward - + def create_new_forward(function_reference: HookFunctionReference): def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = rewritten_forward(module, *args, **kwargs) - return hook.post_forward(module, output) - else: + args, kwargs = function_reference.pre_forward(module, *args, **kwargs) + output = function_reference.forward(*args, **kwargs) + return function_reference.post_forward(module, output) - def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = old_forward(*args, **kwargs) - return hook.post_forward(module, output) + return new_forward + + forward = self._module_ref.forward + fn_ref = HookFunctionReference() + fn_ref.pre_forward = hook.pre_forward + fn_ref.post_forward = hook.post_forward + fn_ref.forward = forward + + if hasattr(hook, "new_forward"): + fn_ref.original_forward = forward + fn_ref.forward = functools.update_wrapper( + functools.partial(hook.new_forward, self._module_ref), hook.new_forward + ) + + rewritten_forward = create_new_forward(fn_ref) self._module_ref.forward = functools.update_wrapper( - functools.partial(new_forward, self._module_ref), old_forward + functools.partial(rewritten_forward, self._module_ref), rewritten_forward ) + hook.fn_ref = fn_ref self.hooks[name] = hook self._hook_order.append(name) + self._fn_refs.append(fn_ref) def get_hook(self, name: str) -> Optional[ModelHook]: - if name not in self.hooks.keys(): - return None - return self.hooks[name] + return self.hooks.get(name, None) def remove_hook(self, name: str, recurse: bool = True) -> None: if name in self.hooks.keys(): + num_hooks = len(self._hook_order) hook = self.hooks[name] + index = self._hook_order.index(name) + fn_ref = self._fn_refs[index] + + old_forward = fn_ref.forward + if fn_ref.original_forward is not None: + old_forward = fn_ref.original_forward + + if index == num_hooks - 1: + self._module_ref.forward = old_forward 
+ else: + self._fn_refs[index + 1].forward = old_forward + self._module_ref = hook.deinitalize_hook(self._module_ref) del self.hooks[name] - self._hook_order.remove(name) + self._hook_order.pop(index) + self._fn_refs.pop(index) if recurse: for module_name, module in self._module_ref.named_modules(): @@ -161,7 +205,7 @@ def remove_hook(self, name: str, recurse: bool = True) -> None: module._diffusers_hook.remove_hook(name, recurse=False) def reset_stateful_hooks(self, recurse: bool = True) -> None: - for hook_name in self._hook_order: + for hook_name in reversed(self._hook_order): hook = self.hooks[hook_name] if hook._is_stateful: hook.reset_state(self._module_ref) @@ -180,9 +224,13 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry return module._diffusers_hook def __repr__(self) -> str: - hook_repr = "" + registry_repr = "" for i, hook_name in enumerate(self._hook_order): - hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + if self.hooks[hook_name].__class__.__repr__ is not object.__repr__: + hook_repr = self.hooks[hook_name].__repr__() + else: + hook_repr = self.hooks[hook_name].__class__.__name__ + registry_repr += f" ({i}) {hook_name} - {hook_repr}" if i < len(self._hook_order) - 1: - hook_repr += "\n" - return f"HookRegistry(\n{hook_repr}\n)" + registry_repr += "\n" + return f"HookRegistry(\n{registry_repr}\n)" diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py new file mode 100644 index 000000000000..9f8597d52f8c --- /dev/null +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -0,0 +1,314 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple, Union + +import torch + +from ..models.attention_processor import Attention, MochiAttention +from ..utils import logging +from .hooks import HookRegistry, ModelHook + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +_ATTENTION_CLASSES = (Attention, MochiAttention) + +_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks") +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") + + +@dataclass +class PyramidAttentionBroadcastConfig: + r""" + Configuration for Pyramid Attention Broadcast. + + Args: + spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`): + The number of times a specific spatial attention broadcast is skipped before computing the attention states + to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., + old attention states will be re-used) before computing the new attention states again. 
+        temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
+            The number of times a specific temporal attention broadcast is skipped before computing the attention
+            states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
+            (i.e., old attention states will be re-used) before computing the new attention states again.
+        cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
+            The number of times a specific cross-attention broadcast is skipped before computing the attention states
+            to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
+            old attention states will be re-used) before computing the new attention states again.
+        spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
+            The range of timesteps to skip in the spatial attention layer. The attention computations will be
+            conditionally skipped if the current timestep is within the specified range.
+        temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
+            The range of timesteps to skip in the temporal attention layer. The attention computations will be
+            conditionally skipped if the current timestep is within the specified range.
+        cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
+            The range of timesteps to skip in the cross-attention layer. The attention computations will be
+            conditionally skipped if the current timestep is within the specified range.
+        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks", "single_transformer_blocks")`):
+            The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
+        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
+            The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
+        cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+            The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
+    """
+
+    spatial_attention_block_skip_range: Optional[int] = None
+    temporal_attention_block_skip_range: Optional[int] = None
+    cross_attention_block_skip_range: Optional[int] = None
+
+    spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+    temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+    cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...]
= _CROSS_ATTENTION_BLOCK_IDENTIFIERS + + current_timestep_callback: Callable[[], int] = None + + # TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase + # so not added for now) + + def __repr__(self) -> str: + return ( + f"PyramidAttentionBroadcastConfig(" + f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n" + f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n" + f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n" + f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n" + f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n" + f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n" + f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n" + f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n" + f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n" + f" current_timestep_callback={self.current_timestep_callback}\n" + ")" + ) + + +class PyramidAttentionBroadcastState: + r""" + State for Pyramid Attention Broadcast. + + Attributes: + iteration (`int`): + The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is + called before starting a new inference forward pass for PAB to work correctly. + cache (`Any`): + The cached output from the previous forward pass. This is used to re-use the attention states when the + attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module. + """ + + def __init__(self) -> None: + self.iteration = 0 + self.cache = None + + def reset(self): + self.iteration = 0 + self.cache = None + + def __repr__(self): + cache_repr = "" + if self.cache is None: + cache_repr = "None" + else: + cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})" + return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})" + + +class PyramidAttentionBroadcastHook(ModelHook): + r"""A hook that applies Pyramid Attention Broadcast to a given module.""" + + _is_stateful = True + + def __init__( + self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int] + ) -> None: + super().__init__() + + self.timestep_skip_range = timestep_skip_range + self.block_skip_range = block_skip_range + self.current_timestep_callback = current_timestep_callback + + def initialize_hook(self, module): + self.state = PyramidAttentionBroadcastState() + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: + is_within_timestep_range = ( + self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1] + ) + should_compute_attention = ( + self.state.cache is None + or self.state.iteration == 0 + or not is_within_timestep_range + or self.state.iteration % self.block_skip_range == 0 + ) + + if should_compute_attention: + output = self.fn_ref.original_forward(*args, **kwargs) + else: + output = self.state.cache + + self.state.cache = output + self.state.iteration += 1 + return output + + def reset_state(self, module: torch.nn.Module) -> None: + self.state.reset() + return module + + +def apply_pyramid_attention_broadcast( + module: torch.nn.Module, + config: PyramidAttentionBroadcastConfig, +): + r""" + Apply [Pyramid Attention 
Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
+
+    PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
+    reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
+    similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal
+    and spatial layers. This allows skipping attention computation in the cross-attention layers more frequently than
+    in the temporal and spatial layers. Applying PAB will, therefore, speed up the inference process.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module to apply Pyramid Attention Broadcast to.
+        config (`PyramidAttentionBroadcastConfig`):
+            The configuration to use for Pyramid Attention Broadcast.
+
+    Example:
+
+    ```python
+    >>> import torch
+    >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+    >>> from diffusers.utils import export_to_video
+
+    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+
+    >>> config = PyramidAttentionBroadcastConfig(
+    ...     spatial_attention_block_skip_range=2,
+    ...     spatial_attention_timestep_skip_range=(100, 800),
+    ...     current_timestep_callback=lambda: pipe.current_timestep,
+    ... )
+    >>> apply_pyramid_attention_broadcast(pipe.transformer, config)
+    ```
+    """
+    if config.current_timestep_callback is None:
+        raise ValueError(
+            "The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
+        )
+
+    if (
+        config.spatial_attention_block_skip_range is None
+        and config.temporal_attention_block_skip_range is None
+        and config.cross_attention_block_skip_range is None
+    ):
+        logger.warning(
+            "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
+            "or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
+            "To avoid this warning, please set one of the above parameters."
+        )
+        config.spatial_attention_block_skip_range = 2
+
+    for name, submodule in module.named_modules():
+        if not isinstance(submodule, _ATTENTION_CLASSES):
+            # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
+            # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
+            # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
+ continue + _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config) + + +def _apply_pyramid_attention_broadcast_on_attention_class( + name: str, module: Attention, config: PyramidAttentionBroadcastConfig +) -> bool: + is_spatial_self_attention = ( + any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) + and config.spatial_attention_block_skip_range is not None + and not getattr(module, "is_cross_attention", False) + ) + is_temporal_self_attention = ( + any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers) + and config.temporal_attention_block_skip_range is not None + and not getattr(module, "is_cross_attention", False) + ) + is_cross_attention = ( + any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers) + and config.cross_attention_block_skip_range is not None + and getattr(module, "is_cross_attention", False) + ) + + block_skip_range, timestep_skip_range, block_type = None, None, None + if is_spatial_self_attention: + block_skip_range = config.spatial_attention_block_skip_range + timestep_skip_range = config.spatial_attention_timestep_skip_range + block_type = "spatial" + elif is_temporal_self_attention: + block_skip_range = config.temporal_attention_block_skip_range + timestep_skip_range = config.temporal_attention_timestep_skip_range + block_type = "temporal" + elif is_cross_attention: + block_skip_range = config.cross_attention_block_skip_range + timestep_skip_range = config.cross_attention_timestep_skip_range + block_type = "cross" + + if block_skip_range is None or timestep_skip_range is None: + logger.info( + f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does ' + f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, " + f"however, that this layer may still be valid for applying PAB. Please specify the correct " + f"block identifiers in the configuration." + ) + return False + + logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}") + _apply_pyramid_attention_broadcast_hook( + module, timestep_skip_range, block_skip_range, config.current_timestep_callback + ) + return True + + +def _apply_pyramid_attention_broadcast_hook( + module: Union[Attention, MochiAttention], + timestep_skip_range: Tuple[int, int], + block_skip_range: int, + current_timestep_callback: Callable[[], int], +): + r""" + Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module. + + Args: + module (`torch.nn.Module`): + The module to apply Pyramid Attention Broadcast to. + timestep_skip_range (`Tuple[int, int]`): + The range of timesteps to skip in the attention layer. The attention computations will be conditionally + skipped if the current timestep is within the specified range. + block_skip_range (`int`): + The number of times a specific attention broadcast is skipped before computing the attention states to + re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old + attention states will be re-used) before computing the new attention states again. + current_timestep_callback (`Callable[[], int]`): + A callback function that returns the current inference timestep. 
+ """ + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback) + registry.register_hook(hook, "pyramid_attention_broadcast") diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e3f291ce2dc7..57a34609d28e 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -39,6 +39,7 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] + _import_structure["cache_utils"] = ["CacheMixin"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnets.controlnet_hunyuan"] = [ @@ -109,6 +110,7 @@ ConsistencyDecoderVAE, VQModel, ) + from .cache_utils import CacheMixin from .controlnets import ( ControlNetModel, ControlNetUnionModel, diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py new file mode 100644 index 000000000000..f2c621b3011a --- /dev/null +++ b/src/diffusers/models/cache_utils.py @@ -0,0 +1,89 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils.logging import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class CacheMixin: + r""" + A class for enable/disabling caching techniques on diffusion models. + + Supported caching techniques: + - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) + """ + + _cache_config = None + + @property + def is_cache_enabled(self) -> bool: + return self._cache_config is not None + + def enable_cache(self, config) -> None: + r""" + Enable caching techniques on the model. + + Args: + config (`Union[PyramidAttentionBroadcastConfig]`): + The configuration for applying the caching technique. Currently supported caching techniques are: + - [`~hooks.PyramidAttentionBroadcastConfig`] + + Example: + + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig + + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> config = PyramidAttentionBroadcastConfig( + ... spatial_attention_block_skip_range=2, + ... spatial_attention_timestep_skip_range=(100, 800), + ... current_timestep_callback=lambda: pipe.current_timestep, + ... 
) + >>> pipe.transformer.enable_cache(config) + ``` + """ + + from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + + if isinstance(config, PyramidAttentionBroadcastConfig): + apply_pyramid_attention_broadcast(self, config) + else: + raise ValueError(f"Cache config {type(config)} is not supported.") + + self._cache_config = config + + def disable_cache(self) -> None: + from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig + + if self._cache_config is None: + logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") + return + + if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): + registry = HookRegistry.check_if_exists_or_initialize(self) + registry.remove_hook("pyramid_attention_broadcast", recurse=True) + else: + raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") + + self._cache_config = None + + def _reset_stateful_cache(self, recurse: bool = True) -> None: + from ..hooks import HookRegistry + + HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index c3039180b81d..583a2482fc07 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -24,6 +24,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from ..cache_utils import CacheMixin from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -156,7 +157,7 @@ def forward( return hidden_states, encoder_hidden_states -class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): """ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index be06f44a9efe..fbdae37ae561 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ from typing import Optional import torch @@ -19,13 +20,14 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid from ..attention import BasicTransformerBlock +from ..cache_utils import CacheMixin from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle -class LatteTransformer3DModel(ModelMixin, ConfigMixin): +class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True """ diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index f32c38394ba4..672f3c2a1dc3 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -24,6 +24,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import AllegroAttnProcessor2_0, Attention +from ..cache_utils import CacheMixin from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -172,7 +173,7 @@ def forward( return hidden_states -class AllegroTransformer3DModel(ModelMixin, ConfigMixin): +class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True """ diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index db8d73856689..d65ad00e057f 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -35,6 +35,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph +from ..cache_utils import CacheMixin from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -227,7 +228,7 @@ def forward( class FluxTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin ): """ The Transformer model introduced in Flux. 
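Each of the transformer models touched above now inherits `CacheMixin`, so caching is toggled on the denoiser itself rather than on the pipeline. A minimal sketch of the intended round trip, reusing the CogVideoX example from the new docs page (the prompt string is illustrative, not part of this patch):

```python
import torch

from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# The callback lets the PAB hook query the timestep of the current denoising step,
# which the pipelines in this diff expose through the new `current_timestep` property.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)

pipe.transformer.enable_cache(config)  # registers PAB hooks on the attention layers
video = pipe("A painter walking through a sunlit field").frames[0]
pipe.transformer.disable_cache()  # removes the hooks and clears the stored config
```

Stateful hook caches are also cleared automatically at the end of a pipeline call, since `maybe_free_model_hooks` now calls `_reset_stateful_cache` on components (see the `pipeline_utils.py` hunk further down).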
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 210a2e711972..4a820d98d584 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -25,6 +25,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor +from ..cache_utils import CacheMixin from ..embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, @@ -502,7 +503,7 @@ def forward( return hidden_states, encoder_hidden_states -class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index d16430f27931..ce4ee510cfa5 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -25,6 +25,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 +from ..cache_utils import CacheMixin from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -305,7 +306,7 @@ def forward( @maybe_allow_in_graph -class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 91aedf2cdbe6..cb36a7a672de 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -683,6 +683,10 @@ def guidance_scale(self): def num_timesteps(self): return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -815,6 +819,7 @@ def __call__( negative_prompt_attention_mask, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._interrupt = False # 2. 
Default height and width to transformer @@ -892,6 +897,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -933,6 +939,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) video = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index d78d5508dc7f..99ae9025cd3e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -494,6 +494,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -627,6 +631,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Default call parameters @@ -705,6 +710,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -763,6 +769,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": # Discard any padding frames that were added for CogVideoX 1.5 latents = latents[:, additional_frames:] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 46e7b9ee468e..e37574ec9cb2 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -540,6 +540,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -680,6 +684,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. 
Default call parameters @@ -766,6 +771,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -818,6 +824,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 58793902345a..59d7c4cad547 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -591,6 +591,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -728,6 +732,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._attention_kwargs = attention_kwargs self._interrupt = False @@ -815,6 +820,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -877,6 +883,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": # Discard any padding frames that were added for CogVideoX 1.5 latents = latents[:, additional_frames:] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 333e3418dca2..c4dc7e574f7e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -564,6 +564,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -700,6 +704,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. 
Default call parameters @@ -786,6 +791,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -844,6 +850,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index f5716dc9c8ea..aa02dc1de5da 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -28,8 +28,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.transformers import FluxTransformer2DModel +from ...models import AutoencoderKL, FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, @@ -620,6 +619,10 @@ def joint_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -775,6 +778,7 @@ def __call__( self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Define call parameters @@ -899,6 +903,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t if image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -957,9 +962,10 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if output_type == "latent": image = latents - else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 5c3d6ce611cc..8cc77ed4c148 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -456,6 +456,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -577,6 +581,7 @@ def __call__( self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False device = self._execution_device @@ -644,6 +649,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = latents.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -678,6 +684,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = 
self.vae.decode(latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index ce4ca313ebc4..578f373e8e3f 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -602,6 +602,10 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -633,7 +637,7 @@ def __call__( clean_caption: bool = True, mask_feature: bool = True, enable_temporal_attentions: bool = True, - decode_chunk_size: Optional[int] = None, + decode_chunk_size: int = 14, ) -> Union[LattePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -729,6 +733,7 @@ def __call__( negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._interrupt = False # 2. Default height and width to transformer @@ -790,6 +795,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -850,6 +856,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if output_type == "latents": deprecation_message = ( "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead." @@ -858,7 +866,7 @@ def __call__( output_type = "latent" if not output_type == "latent": - video = self.decode_latents(latents, video_length, decode_chunk_size=14) + video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index a3028c50d8b7..d1f88b02c5cc 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -21,8 +21,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import Mochi1LoraLoaderMixin -from ...models.autoencoders import AutoencoderKLMochi -from ...models.transformers import MochiTransformer3DModel +from ...models import AutoencoderKLMochi, MochiTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( is_torch_xla_available, @@ -467,6 +466,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -591,6 +594,7 @@ def __call__( self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Define call parameters @@ -660,6 +664,9 @@ def __call__( if self.interrupt: continue + # Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need + # to make sure we're using the correct non-reversed timestep values. 
+ self._current_timestep = 1000 - t latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) @@ -705,6 +712,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if output_type == "latent": video = latents else: diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d56a2ce6eb30..0c1371c7556f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1133,11 +1133,20 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t def maybe_free_model_hooks(self): r""" - Function that offloads all components, removes all model hooks that were added when using - `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function - is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it - functions correctly when applying enable_model_cpu_offload. + Method that performs the following: + - Offloads all components. + - Removes all model hooks that were added when using `enable_model_cpu_offload`, and then applies them again. + In case the model has not been offloaded, this function is a no-op. + - Resets stateful diffusers hooks of denoiser components if they were added with + [`~hooks.HookRegistry.register_hook`]. + + Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions + correctly when applying `enable_model_cpu_offload`. """ + for component in self.components.values(): + if hasattr(component, "_reset_stateful_cache"): + component._reset_stateful_cache() + if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 183d6beb35c3..6a1978944c9f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,40 @@ from ..utils import DummyObject, requires_backends +class HookRegistry(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PyramidAttentionBroadcastConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +def apply_pyramid_attention_broadcast(*args, **kwargs): + requires_backends(apply_pyramid_attention_broadcast, ["torch"]) + + class AllegroTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -197,6 +231,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CacheMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + 
@classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CogVideoXTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py new file mode 100644 index 000000000000..74bd43c52315 --- /dev/null +++ b/tests/hooks/test_hooks.py @@ -0,0 +1,382 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch + +from diffusers.hooks import HookRegistry, ModelHook +from diffusers.training_utils import free_memory +from diffusers.utils.logging import get_logger +from diffusers.utils.testing_utils import CaptureLogger, torch_device + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class DummyBlock(torch.nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.proj_in = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.proj_out = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj_in(x) + x = self.activation(x) + x = self.proj_out(x) + return x + + +class DummyModel(torch.nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None: + super().__init__() + + self.linear_1 = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.blocks = torch.nn.ModuleList( + [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)] + ) + self.linear_2 = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x) + x = self.activation(x) + for block in self.blocks: + x = block(x) + x = self.linear_2(x) + return x + + +class AddHook(ModelHook): + def __init__(self, value: int): + super().__init__() + self.value = value + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + logger.debug("AddHook pre_forward") + args = ((x + self.value) if torch.is_tensor(x) else x for x in args) + return args, kwargs + + def post_forward(self, module, output): + logger.debug("AddHook post_forward") + return output + + +class MultiplyHook(ModelHook): + def __init__(self, value: int): + super().__init__() + self.value = value + + def pre_forward(self, module, *args, **kwargs): + logger.debug("MultiplyHook pre_forward") + args = ((x * self.value) if torch.is_tensor(x) else x for x in args) + return args, kwargs + + def post_forward(self, module, output): + logger.debug("MultiplyHook post_forward") + return output + + def __repr__(self): + return f"MultiplyHook(value={self.value})" + + +class StatefulAddHook(ModelHook): + _is_stateful = True + + def __init__(self, value: int): + super().__init__() + self.value = value + self.increment = 0 + + def pre_forward(self, module, *args, **kwargs): + logger.debug("StatefulAddHook pre_forward") + add_value = self.value + 
self.increment + self.increment += 1 + args = ((x + add_value) if torch.is_tensor(x) else x for x in args) + return args, kwargs + + def reset_state(self, module): + self.increment = 0 + + +class SkipLayerHook(ModelHook): + def __init__(self, skip_layer: bool): + super().__init__() + self.skip_layer = skip_layer + + def pre_forward(self, module, *args, **kwargs): + logger.debug("SkipLayerHook pre_forward") + return args, kwargs + + def new_forward(self, module, *args, **kwargs): + logger.debug("SkipLayerHook new_forward") + if self.skip_layer: + return args[0] + return self.fn_ref.original_forward(*args, **kwargs) + + def post_forward(self, module, output): + logger.debug("SkipLayerHook post_forward") + return output + + +class HookTests(unittest.TestCase): + in_features = 4 + hidden_features = 8 + out_features = 4 + num_layers = 2 + + def setUp(self): + params = self.get_module_parameters() + self.model = DummyModel(**params) + self.model.to(torch_device) + + def tearDown(self): + super().tearDown() + + del self.model + gc.collect() + free_memory() + + def get_module_parameters(self): + return { + "in_features": self.in_features, + "hidden_features": self.hidden_features, + "out_features": self.out_features, + "num_layers": self.num_layers, + } + + def get_generator(self): + return torch.manual_seed(0) + + def test_hook_registry(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(1), "add_hook") + registry.register_hook(MultiplyHook(2), "multiply_hook") + + registry_repr = repr(registry) + expected_repr = ( + "HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")" + ) + + self.assertEqual(len(registry.hooks), 2) + self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"]) + self.assertEqual(registry_repr, expected_repr) + + registry.remove_hook("add_hook") + + self.assertEqual(len(registry.hooks), 1) + self.assertEqual(registry._hook_order, ["multiply_hook"]) + + def test_stateful_hook(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(StatefulAddHook(1), "stateful_add_hook") + + self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0) + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + num_repeats = 3 + + for i in range(num_repeats): + result = self.model(input) + if i == 0: + output1 = result + + self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats) + + registry.reset_stateful_hooks() + output2 = self.model(input) + + self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1) + self.assertTrue(torch.allclose(output1, output2)) + + def test_inference(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(1), "add_hook") + registry.register_hook(MultiplyHook(2), "multiply_hook") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + output1 = self.model(input).mean().detach().cpu().item() + + registry.remove_hook("multiply_hook") + new_input = input * 2 + output2 = self.model(new_input).mean().detach().cpu().item() + + registry.remove_hook("add_hook") + new_input = input * 2 + 1 + output3 = self.model(new_input).mean().detach().cpu().item() + + self.assertAlmostEqual(output1, output2, places=5) + self.assertAlmostEqual(output1, output3, places=5) + + def test_skip_layer_hook(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + 
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") + + input = torch.zeros(1, 4, device=torch_device) + output = self.model(input).mean().detach().cpu().item() + self.assertEqual(output, 0.0) + + registry.remove_hook("skip_layer_hook") + registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook") + output = self.model(input).mean().detach().cpu().item() + self.assertNotEqual(output, 0.0) + + def test_skip_layer_internal_block(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1) + input = torch.zeros(1, 4, device=torch_device) + + registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") + with self.assertRaises(RuntimeError) as cm: + self.model(input).mean().detach().cpu().item() + self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception)) + + registry.remove_hook("skip_layer_hook") + output = self.model(input).mean().detach().cpu().item() + self.assertNotEqual(output, 0.0) + + registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1]) + registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") + output = self.model(input).mean().detach().cpu().item() + self.assertNotEqual(output, 0.0) + + def test_invocation_order_stateful_first(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(StatefulAddHook(1), "add_hook") + registry.register_hook(AddHook(2), "add_hook_2") + registry.register_hook(MultiplyHook(3), "multiply_hook") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + + logger = get_logger(__name__) + logger.setLevel("DEBUG") + + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "MultiplyHook pre_forward\n" + "AddHook pre_forward\n" + "StatefulAddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "MultiplyHook pre_forward\n" + "AddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + def test_invocation_order_stateful_middle(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(2), "add_hook") + registry.register_hook(StatefulAddHook(1), "add_hook_2") + registry.register_hook(MultiplyHook(3), "multiply_hook") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + + logger = get_logger(__name__) + logger.setLevel("DEBUG") + + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "MultiplyHook pre_forward\n" + "StatefulAddHook pre_forward\n" + "AddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + 
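+        # StatefulAddHook only overrides pre_forward, so it never appears on the post_forward side
+        # of the log; with the plain AddHook removed, only MultiplyHook logs a post_forward call.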
expected_invocation_order_log = ( + ("MultiplyHook pre_forward\nStatefulAddHook pre_forward\nMultiplyHook post_forward\n") + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook_2") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + def test_invocation_order_stateful_last(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(1), "add_hook") + registry.register_hook(MultiplyHook(2), "multiply_hook") + registry.register_hook(StatefulAddHook(3), "add_hook_2") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + + logger = get_logger(__name__) + logger.setLevel("DEBUG") + + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "StatefulAddHook pre_forward\n" + "MultiplyHook pre_forward\n" + "AddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ("StatefulAddHook pre_forward\nMultiplyHook pre_forward\nMultiplyHook post_forward\n") + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index 322be373641a..2a4d0a36dffa 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -34,13 +34,13 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np enable_full_determinism() -class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = AllegroPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -59,14 +59,14 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False test_layerwise_casting = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = AllegroTransformer3DModel( num_attention_heads=2, attention_head_dim=12, in_channels=4, out_channels=4, - num_layers=1, + num_layers=num_layers, cross_attention_dim=24, sample_width=8, sample_height=8, diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 9ce3d8e9de31..750f20f8fbe5 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -32,6 +32,7 @@ from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from 
..test_pipelines_common import ( PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, check_qkv_fusion_processors_exist, to_np, @@ -41,7 +42,7 @@ enable_full_determinism() -class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = CogVideoXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -60,7 +61,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False test_layerwise_casting = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings @@ -72,7 +73,7 @@ def get_dummy_components(self): out_channels=4, time_embed_dim=2, text_embed_dim=32, # Must match with tiny-random-t5 - num_layers=1, + num_layers=num_layers, sample_width=2, # latent width: 2 -> final width: 16 sample_height=2, # latent height: 2 -> final height: 16 sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index a3bc1658de74..bab343a5954c 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -19,12 +19,15 @@ from ..test_pipelines_common import ( FluxIPAdapterTesterMixin, PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, check_qkv_fusion_processors_exist, ) -class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin): +class FluxPipelineFastTests( + unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin +): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -33,13 +36,13 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte test_xformers_attention = False test_layerwise_casting = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = FluxTransformer2DModel( patch_size=1, in_channels=4, - num_layers=1, - num_single_layers=1, + num_layers=num_layers, + num_single_layers=num_single_layers, attention_head_dim=16, num_attention_heads=2, joint_attention_dim=32, diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index ce03381f90d2..1ecfee666fcd 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -30,13 +30,13 @@ torch_device, ) -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np enable_full_determinism() -class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = HunyuanVideoPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", 
"prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -55,15 +55,15 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False test_layerwise_casting = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = HunyuanVideoTransformer3DModel( in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=10, - num_layers=1, - num_single_layers=1, + num_layers=num_layers, + num_single_layers=num_single_layers, num_refiner_layers=1, patch_size=1, patch_size_t=1, diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 2d5bcba8237a..64459a659179 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -27,6 +27,7 @@ DDIMScheduler, LattePipeline, LatteTransformer3DModel, + PyramidAttentionBroadcastConfig, ) from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -38,13 +39,13 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np enable_full_determinism() -class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = LattePipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -54,11 +55,23 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params test_layerwise_casting = True - def get_dummy_components(self): + pab_config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, + temporal_attention_block_skip_range=2, + cross_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(100, 700), + temporal_attention_timestep_skip_range=(100, 800), + cross_attention_timestep_skip_range=(100, 800), + spatial_attention_block_identifiers=["transformer_blocks"], + temporal_attention_block_identifiers=["temporal_transformer_blocks"], + cross_attention_block_identifiers=["transformer_blocks"], + ) + + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = LatteTransformer3DModel( sample_size=8, - num_layers=1, + num_layers=num_layers, patch_size=2, attention_head_dim=8, num_attention_heads=3, diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 139778994b87..de5faa185c2f 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -24,10 +24,12 @@ DDIMScheduler, DiffusionPipeline, KolorsPipeline, + PyramidAttentionBroadcastConfig, StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, ) +from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin from diffusers.models.attention_processor import AttnProcessor @@ -2322,6 +2324,141 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): self.assertLess(max_diff, expected_max_difference) +class 
PyramidAttentionBroadcastTesterMixin: + pab_config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(100, 800), + spatial_attention_block_identifiers=["transformer_blocks"], + ) + + def test_pyramid_attention_broadcast_layers(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + num_layers = 0 + num_single_layers = 0 + dummy_component_kwargs = {} + dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters + if "num_layers" in dummy_component_parameters: + num_layers = 2 + dummy_component_kwargs["num_layers"] = num_layers + if "num_single_layers" in dummy_component_parameters: + num_single_layers = 2 + dummy_component_kwargs["num_single_layers"] = num_single_layers + + components = self.get_dummy_components(**dummy_component_kwargs) + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + self.pab_config.current_timestep_callback = lambda: pipe.current_timestep + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + denoiser.enable_cache(self.pab_config) + + expected_hooks = 0 + if self.pab_config.spatial_attention_block_skip_range is not None: + expected_hooks += num_layers + num_single_layers + if self.pab_config.temporal_attention_block_skip_range is not None: + expected_hooks += num_layers + num_single_layers + if self.pab_config.cross_attention_block_skip_range is not None: + expected_hooks += num_layers + num_single_layers + + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + count = 0 + for module in denoiser.modules(): + if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") + if hook is None: + continue + count += 1 + self.assertTrue( + isinstance(hook, PyramidAttentionBroadcastHook), + "Hook should be of type PyramidAttentionBroadcastHook.", + ) + self.assertTrue(hook.state.cache is None, "Cache should be None at initialization.") + self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.") + + # Perform dummy inference step to ensure state is updated + def pab_state_check_callback(pipe, i, t, kwargs): + for module in denoiser.modules(): + if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") + if hook is None: + continue + self.assertTrue( + hook.state.cache is not None, + "Cache should have updated during inference.", + ) + self.assertTrue( + hook.state.iteration == i + 1, + "Hook iteration state should have updated during inference.", + ) + return {} + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 2 + inputs["callback_on_step_end"] = pab_state_check_callback + pipe(**inputs)[0] + + # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states + for module in denoiser.modules(): + if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") + if hook is None: + continue + self.assertTrue( + hook.state.cache is None, + "Cache should be reset to None after inference.", + ) + self.assertTrue( + hook.state.iteration == 0, + "Iteration should be reset to 0 after inference.", + ) + + def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2): + # We need to use higher tolerance because we are using a random model. With a converged/trained + # model, the tolerance can be lower. 
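+        # The check below runs three passes on identical inputs: a baseline without PAB, a pass with
+        # the cache enabled (expected to stay within `expected_atol` of the baseline), and a pass
+        # after disabling the cache (expected to match the baseline almost exactly).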
+
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        num_layers = 2
+        components = self.get_dummy_components(num_layers=num_layers)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # Run inference without PAB
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        output = pipe(**inputs)[0]
+        original_image_slice = output.flatten()
+        original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:]))
+
+        # Run inference with PAB enabled
+        self.pab_config.current_timestep_callback = lambda: pipe.current_timestep
+        denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+        denoiser.enable_cache(self.pab_config)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        output = pipe(**inputs)[0]
+        image_slice_pab_enabled = output.flatten()
+        image_slice_pab_enabled = np.concatenate((image_slice_pab_enabled[:8], image_slice_pab_enabled[-8:]))
+
+        # Run inference with PAB disabled
+        denoiser.disable_cache()
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["num_inference_steps"] = 4
+        output = pipe(**inputs)[0]
+        image_slice_pab_disabled = output.flatten()
+        image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
+
+        assert np.allclose(
+            original_image_slice, image_slice_pab_enabled, atol=expected_atol
+        ), "PAB outputs should not differ much in the specified timestep range."
+        assert np.allclose(
+            original_image_slice, image_slice_pab_disabled, atol=1e-4
+        ), "Outputs from normal inference and after disabling cache should not differ."
+
+
 # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
 # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
 # reference image.