diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py
new file mode 100644
index 000000000000..3be77dd4cedf
--- /dev/null
+++ b/src/diffusers/hooks/_common.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..models.attention_processor import Attention, MochiAttention
+
+
+_ATTENTION_CLASSES = (Attention, MochiAttention)
+
+_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
+_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
+
+_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
+    {
+        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    }
+)
diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
new file mode 100644
index 000000000000..d229cdfd3010
--- /dev/null
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -0,0 +1,262 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Tuple, Union
+
+import torch
+
+from ..utils import get_logger
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
+from .hooks import HookRegistry, ModelHook
+from .utils import _extract_return_information
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
+_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
+_FBC_BLOCK_HOOK = "fbc_block_hook"
+
+
+@dataclass
+class FirstBlockCacheConfig:
+    r"""
+    Configuration for [First Block
+    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
+
+    Args:
+        threshold (`float`, defaults to `0.05`):
+            The threshold used to determine whether a forward pass through all layers of the model is required. A
+            higher threshold usually results in a lower number of forward passes and faster inference, but might lead
+            to poorer generation quality. A lower threshold may not result in a significant generation speedup. The
+            threshold is compared against the absmean difference of the residuals between the current and cached
+            outputs from the first transformer block. If the difference is below the threshold, the forward pass is
+            skipped.
+ """ + + threshold: float = 0.05 + + +class FBCSharedBlockState: + def __init__(self) -> None: + self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.head_block_residual: torch.Tensor = None + self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.should_compute: bool = True + + def reset(self): + self.tail_block_residuals = None + self.should_compute = True + + def __repr__(self): + return f"FirstBlockCacheSharedState(cache={self.cache})" + + +class FBCHeadBlockHook(ModelHook): + _is_stateful = True + + def __init__(self, shared_state: FBCSharedBlockState, threshold: float): + self.shared_state = shared_state + self.threshold = threshold + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + inputs = inspect.signature(module.__class__.forward) + inputs_index_to_str = dict(enumerate(inputs.parameters.keys())) + inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()} + + try: + outputs = _extract_return_information(module.__class__.forward) + outputs_index_to_str = dict(enumerate(outputs)) + outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()} + except RuntimeError: + logger.error(f"Failed to extract return information for {module.__class__}") + raise NotImplementedError( + f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at " + f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example " + f"in order for us to add support for this module." + ) + + self._inputs_index_to_str = inputs_index_to_str + self._inputs_str_to_index = inputs_str_to_index + self._outputs_index_to_str = outputs_index_to_str + self._outputs_str_to_index = outputs_str_to_index + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + hs_input_idx = self._inputs_str_to_index.get("hidden_states") + ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None) + original_hs = kwargs.get("hidden_states", None) + original_ehs = kwargs.get("encoder_hidden_states", None) + original_hs = original_hs if original_hs is not None else args[hs_input_idx] + if ehs_input_idx is not None: + original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx] + + hs_output_idx = self._outputs_str_to_index.get("hidden_states") + ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None) + assert (ehs_input_idx is None) == (ehs_output_idx is None) + + output = self.fn_ref.original_forward(*args, **kwargs) + + hs_residual = None + if isinstance(output, tuple): + hs_residual = output[hs_output_idx] - original_hs + else: + hs_residual = output - original_hs + + should_compute = self._should_compute_remaining_blocks(hs_residual) + self.shared_state.should_compute = should_compute + + hs, ehs = None, None + if not should_compute: + # Apply caching + logger.info("Skipping forward pass through remaining blocks") + hs = self.shared_state.tail_block_residuals[0] + output[hs_output_idx] + if ehs_output_idx is not None: + ehs = self.shared_state.tail_block_residuals[1] + output[ehs_output_idx] + + if isinstance(output, tuple): + return_output = [None] * len(output) + return_output[hs_output_idx] = hs + return_output[ehs_output_idx] = ehs + return_output = tuple(return_output) + else: + return_output = hs + return return_output + else: + logger.info("Computing forward pass through remaining blocks") + if isinstance(output, tuple): + head_block_output = [None] * len(output) + 
head_block_output[0] = output[hs_output_idx] + head_block_output[1] = output[ehs_output_idx] + else: + head_block_output = output + self.shared_state.head_block_output = head_block_output + self.shared_state.head_block_residual = hs_residual + return output + + def reset_state(self, module): + self.shared_state.reset() + return module + + def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: + if self.shared_state.head_block_residual is None: + return True + prev_hs_residual = self.shared_state.head_block_residual + hs_absmean = (hs_residual - prev_hs_residual).abs().mean() + prev_hs_mean = prev_hs_residual.abs().mean() + diff = (hs_absmean / prev_hs_mean).item() + logger.info(f"Diff: {diff}, Threshold: {self.threshold}") + return diff > self.threshold + + +class FBCBlockHook(ModelHook): + def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False): + super().__init__() + self.shared_state = shared_state + self.is_tail = is_tail + + def initialize_hook(self, module): + inputs = inspect.signature(module.__class__.forward) + inputs_index_to_str = dict(enumerate(inputs.parameters.keys())) + inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()} + + try: + outputs = _extract_return_information(module.__class__.forward) + outputs_index_to_str = dict(enumerate(outputs)) + outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()} + except RuntimeError: + logger.error(f"Failed to extract return information for {module.__class__}") + raise NotImplementedError( + f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at " + f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example " + f"in order for us to add support for this module." 
+ ) + + self._inputs_index_to_str = inputs_index_to_str + self._inputs_str_to_index = inputs_str_to_index + self._outputs_index_to_str = outputs_index_to_str + self._outputs_str_to_index = outputs_str_to_index + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + hs_input_idx = self._inputs_str_to_index.get("hidden_states") + ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None) + original_hs = kwargs.get("hidden_states", None) + original_ehs = kwargs.get("encoder_hidden_states", None) + original_hs = original_hs if original_hs is not None else args[hs_input_idx] + if ehs_input_idx is not None: + original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx] + + hs_output_idx = self._outputs_str_to_index.get("hidden_states") + ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None) + assert (ehs_input_idx is None) == (ehs_output_idx is None) + + if self.shared_state.should_compute: + output = self.fn_ref.original_forward(*args, **kwargs) + if self.is_tail: + hs_residual, ehs_residual = None, None + if isinstance(output, tuple): + hs_residual = output[hs_output_idx] - self.shared_state.head_block_output[0] + ehs_residual = output[ehs_output_idx] - self.shared_state.head_block_output[1] + else: + hs_residual = output - self.shared_state.head_block_output + self.shared_state.tail_block_residuals = (hs_residual, ehs_residual) + return output + + output_count = len(self._outputs_index_to_str.keys()) + return_output = [None] * output_count if output_count > 1 else original_hs + if output_count == 1: + return_output = original_hs + else: + return_output[hs_output_idx] = original_hs + return_output[ehs_output_idx] = original_ehs + return return_output + + +def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None: + shared_state = FBCSharedBlockState() + remaining_blocks = [] + + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for block in submodule: + remaining_blocks.append((name, block)) + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'") + apply_fbc_head_block_hook(head_block, shared_state, config.threshold) + + for name, block in remaining_blocks: + logger.debug(f"Apply FBCBlockHook to '{name}'") + apply_fbc_block_hook(block, shared_state) + + logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'") + apply_fbc_block_hook(tail_block, shared_state, is_tail=True) + + +def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCHeadBlockHook(state, threshold) + registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK) + + +def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCBlockHook(state, is_tail) + registry.register_hook(hook, _FBC_BLOCK_HOOK) diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 9f8597d52f8c..e6c06aaa4456 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -20,19 +20,18 @@ from ..models.attention_processor 
import Attention, MochiAttention from ..utils import logging +from ._common import ( + _ATTENTION_CLASSES, + _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, +) from .hooks import HookRegistry, ModelHook logger = logging.get_logger(__name__) # pylint: disable=invalid-name -_ATTENTION_CLASSES = (Attention, MochiAttention) - -_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks") -_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) -_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") - - @dataclass class PyramidAttentionBroadcastConfig: r""" @@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig: temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800) cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800) - spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS - temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS - cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS + spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS + temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS + cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS current_timestep_callback: Callable[[], int] = None diff --git a/src/diffusers/hooks/utils.py b/src/diffusers/hooks/utils.py new file mode 100644 index 000000000000..a72d5db59192 --- /dev/null +++ b/src/diffusers/hooks/utils.py @@ -0,0 +1,59 @@ +import ast +import inspect +import textwrap +from typing import List + + +def _extract_return_information(func) -> List[str]: + """Extracts return variable names in order from a function.""" + try: + source = inspect.getsource(func) + source = textwrap.dedent(source) # Modify indentation to make parsing compatible + except (OSError, TypeError): + try: + source_file = inspect.getfile(func) + with open(source_file, "r", encoding="utf-8") as f: + source = f.read() + + # Extract function definition manually + source_lines = source.splitlines() + func_name = func.__name__ + start_line = None + indent_level = None + extracted_lines = [] + + for i, line in enumerate(source_lines): + stripped = line.strip() + if stripped.startswith(f"def {func_name}("): + start_line = i + indent_level = len(line) - len(line.lstrip()) + extracted_lines.append(line) + continue + + if start_line is not None: + # Stop when indentation level decreases (end of function) + current_indent = len(line) - len(line.lstrip()) + if current_indent <= indent_level and line.strip(): + break + extracted_lines.append(line) + + source = "\n".join(extracted_lines) + except Exception as e: + raise RuntimeError(f"Failed to retrieve function source: {e}") + + # Parse source code using AST + tree = ast.parse(source) + return_vars = [] + + class ReturnVisitor(ast.NodeVisitor): + def visit_Return(self, node): + if isinstance(node.value, ast.Tuple): + # Multiple return values + return_vars.extend(var.id for var in node.value.elts if isinstance(var, ast.Name)) + elif isinstance(node.value, ast.Name): + # Single return value + return_vars.append(node.value.id) + + visitor = ReturnVisitor() + visitor.visit(tree) + return return_vars diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 
8a36f2254e44..5ab675defa09 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -87,10 +87,13 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): def forward( self, hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) @@ -108,7 +111,10 @@ def forward( if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) - return hidden_states + encoder_hidden_states, hidden_states = hidden_states.split( + [encoder_hidden_states.size(1), hidden_states.size(1) - encoder_hidden_states.size(1)], dim=1 + ) + return hidden_states, encoder_hidden_states @maybe_allow_in_graph @@ -224,7 +230,7 @@ def forward( if encoder_hidden_states.dtype == torch.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - return encoder_hidden_states, hidden_states + return hidden_states, encoder_hidden_states class FluxTransformer2DModel( @@ -517,7 +523,7 @@ def forward( for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, @@ -526,7 +532,7 @@ def forward( ) else: - encoder_hidden_states, hidden_states = block( + hidden_states, encoder_hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, @@ -545,20 +551,21 @@ def forward( ) else: hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + hidden_states, encoder_hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, @@ -568,12 +575,7 @@ def forward( if controlnet_single_block_samples is not None: interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) interval_control = int(np.ceil(interval_control)) - hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( - hidden_states[:, encoder_hidden_states.shape[1] :, ...] - + controlnet_single_block_samples[index_block // interval_control] - ) - - hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states)
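For reviewers, here is a minimal sketch of how the new hook could be exercised end to end. It is an illustration, not part of the patch: the checkpoint id, device, prompt, and threshold are placeholders, and the imports use the new module path added in this diff (whether the names are also re-exported from a higher-level namespace is not shown here).

import torch
from diffusers import FluxPipeline
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache

# Illustrative checkpoint; any model whose top-level ModuleList names match
# _ALL_TRANSFORMER_BLOCK_IDENTIFIERS ("transformer_blocks", "single_transformer_blocks", ...) qualifies.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Registers FBCHeadBlockHook on the first block and FBCBlockHook on the rest of the denoiser's blocks.
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.05))

image = pipe("A photo of a cat", num_inference_steps=28).images[0]
image.save("cat.png")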
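The block hooks depend on `_extract_return_information` to map return-value names to tuple positions. A small, self-contained illustration of what it recovers is below, assuming the probed function lives in a real source file (the AST path needs `inspect.getsource` to succeed, so functions typed into a bare REPL fall back to the error path); the file and function names are hypothetical.

# demo_extract_returns.py -- hypothetical module, so inspect.getsource can locate the source
from diffusers.hooks.utils import _extract_return_information


def forward(self, hidden_states, encoder_hidden_states, temb):
    # Stand-in block body; only the names in the return statement matter for extraction.
    hidden_states = hidden_states + temb
    return hidden_states, encoder_hidden_states


print(_extract_return_information(forward))
# Expected: ['hidden_states', 'encoder_hidden_states'] -- the ordering the hooks use to build
# _outputs_str_to_index and locate hidden_states / encoder_hidden_states in tuple outputs.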