From 0c4c1a843089a6411233a69b7e27473d78e869c8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 10:04:13 +0200 Subject: [PATCH 01/24] cfg; slg; pag; sdxl without controlnet --- src/diffusers/__init__.py | 15 + src/diffusers/guiders/__init__.py | 24 ++ .../guiders/classifier_free_guidance.py | 111 +++++++ src/diffusers/guiders/guider_utils.py | 148 ++++++++++ src/diffusers/guiders/skip_layer_guidance.py | 218 ++++++++++++++ src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/_common.py | 32 +++ src/diffusers/hooks/_helpers.py | 271 ++++++++++++++++++ src/diffusers/hooks/layer_skip.py | 194 +++++++++++++ .../pipeline_stable_diffusion_xl_modular.py | 147 +++++----- src/diffusers/utils/torch_utils.py | 5 + 11 files changed, 1093 insertions(+), 73 deletions(-) create mode 100644 src/diffusers/guiders/__init__.py create mode 100644 src/diffusers/guiders/classifier_free_guidance.py create mode 100644 src/diffusers/guiders/guider_utils.py create mode 100644 src/diffusers/guiders/skip_layer_guidance.py create mode 100644 src/diffusers/hooks/_common.py create mode 100644 src/diffusers/hooks/_helpers.py create mode 100644 src/diffusers/hooks/layer_skip.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 440c67da629d..d8e274cb93b6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -33,6 +33,7 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "guiders": [], "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], @@ -129,12 +130,20 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["guiders"].extend( + [ + "ClassifierFreeGuidance", + "SkipLayerGuidance", + ] + ) _import_structure["hooks"].extend( [ "FasterCacheConfig", "HookRegistry", "PyramidAttentionBroadcastConfig", + "LayerSkipConfig", "apply_faster_cache", + "apply_layer_skip", "apply_pyramid_attention_broadcast", ] ) @@ -711,10 +720,16 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .guiders import ( + ClassifierFreeGuidance, + SkipLayerGuidance, + ) from .hooks import ( FasterCacheConfig, HookRegistry, + LayerSkipConfig, PyramidAttentionBroadcastConfig, + apply_layer_skip, apply_faster_cache, apply_pyramid_attention_broadcast, ) diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py new file mode 100644 index 000000000000..adef65277b6d --- /dev/null +++ b/src/diffusers/guiders/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Union + +from ..utils import is_torch_available + + +if is_torch_available(): + from .classifier_free_guidance import ClassifierFreeGuidance + from .skip_layer_guidance import SkipLayerGuidance + + GuiderType = Union[ClassifierFreeGuidance, SkipLayerGuidance] diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py new file mode 100644 index 000000000000..4048d70484c5 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -0,0 +1,111 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union, Tuple, List + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class ClassifierFreeGuidance(BaseGuidance): + """ + Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by + jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during + inference. This allows the model to trade off between generation quality and sample diversity. + The original paper proposes scaling and shifting the conditional distribution based on the difference between + conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen + paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in + theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + The intuition behind the original formulation can be thought of as moving the conditional distribution estimates + further away from the unconditional distribution estimates, while the diffusers-native implementation can be + thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of + the unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.). + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the + paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0 + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + return is_within_range and not is_close diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py new file mode 100644 index 000000000000..ecde7334b2b9 --- /dev/null +++ b/src/diffusers/guiders/guider_utils.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union + +import torch + +from ..utils import get_logger + + +if TYPE_CHECKING: + from ..models.attention_processor import AttentionProcessor + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class BaseGuidance: + r"""Base class providing the skeleton for implementing guidance techniques.""" + + _input_predictions = None + + def __init__(self, start: float = 0.0, stop: float = 1.0): + self._start = start + self._stop = stop + self._step: int = None + self._num_inference_steps: int = None + self._timestep: torch.LongTensor = None + self._preds: Dict[str, torch.Tensor] = {} + self._num_outputs_prepared: int = 0 + + if not (0.0 <= start < 1.0): + raise ValueError( + f"Expected `start` to be between 0.0 and 1.0, but got {start}." + ) + if not (start <= stop <= 1.0): + raise ValueError( + f"Expected `stop` to be between {start} and 1.0, but got {stop}." + ) + + if self._input_predictions is None or not isinstance(self._input_predictions, list): + raise ValueError( + "`_input_predictions` must be a list of required prediction names for the guidance technique." + ) + + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + self._step = step + self._num_inference_steps = num_inference_steps + self._timestep = timestep + self._preds = {} + self._num_outputs_prepared = 0 + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + raise NotImplementedError("GuidanceMixin::prepare_inputs must be implemented in subclasses.") + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + raise NotImplementedError("GuidanceMixin::prepare_outputs must be implemented in subclasses.") + + def __call__(self, **kwargs) -> Any: + if len(kwargs) != self.num_conditions: + raise ValueError( + f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments." + ) + return self.forward(**kwargs) + + def forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.") + + @property + def num_conditions(self) -> int: + raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.") + + @property + def outputs(self) -> Dict[str, torch.Tensor]: + return self._preds + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. 
+ """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + """ + Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly + prepared based on required number of conditions. This function is used in the `prepare_inputs` method of the + `GuidanceMixin` class. + + Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements: + - The first element is the conditional input. + - The second element is the unconditional input or None. + + If only the conditional input is provided, it will be repeated for all batches. + + If both conditional and unconditional inputs are provided, they are alternated as batches of data. + """ + list_of_inputs = [] + for arg in args: + if arg is None or isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + # Alternating conditional and unconditional inputs as batches + inputs = [arg[i % 2] for i in range(num_conditions)] + list_of_inputs.append(inputs) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py new file mode 100644 index 000000000000..e20a700fee6a --- /dev/null +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -0,0 +1,218 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import List, Optional, Tuple, Union + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +class SkipLayerGuidance(BaseGuidance): + """ + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG): + https://huggingface.co/papers/2411.18664 + SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by + skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional + batch of data, apart from the conditional and unconditional batches already used in CFG + ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions + based on the difference between the conditional predictions with and without layer skipping. + The intuition behind SLG can be thought of as moving the CFG predicted distribution estimates further away from + worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse + version of the model for the conditional prediction). + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving + generation quality in video diffusion models. + Additional reading: + - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are + defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + skip_layer_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + skip_layer_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + + def __init__( + self, + guidance_scale: float = 7.5, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.01, + stop: float = 0.2, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if skip_layer_guidance_layers is None and skip_layer_config is None: + raise ValueError( + "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." + ) + if skip_layer_guidance_layers is not None and skip_layer_config is not None: + raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.") + + if skip_layer_guidance_layers is not None: + if isinstance(skip_layer_guidance_layers, int): + skip_layer_guidance_layers = [skip_layer_guidance_layers] + if not isinstance(skip_layer_guidance_layers, list): + raise ValueError( + f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}." + ) + skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers] + + if isinstance(skip_layer_config, LayerSkipConfig): + skip_layer_config = [skip_layer_config] + + if not isinstance(skip_layer_config, list): + raise ValueError( + f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}." + ) + + self.skip_layer_config = skip_layer_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + if self._num_outputs_prepared == 0 and self._is_slg_enabled(): + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if arg is None or isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + list_of_inputs.append([arg[0], arg[1], arg[0]]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." 
+ ) + return tuple(list_of_inputs) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + if not self._is_cfg_enabled() and self._is_slg_enabled(): + # If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip + # to avoid writing into pred_uncond which is not used + if self._num_outputs_prepared == 2: + key = "pred_cond_skip" + self._preds[key] = pred + + if self._num_outputs_prepared == self.num_conditions: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_slg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + elif not self._is_slg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_skip = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_slg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + return is_within_range and not is_close + + def _is_slg_enabled(self) -> bool: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + return is_within_range and not is_zero diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 764ceb25b465..142ff860371c 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -5,5 +5,6 @@ from .faster_cache import FasterCacheConfig, apply_faster_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook + from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff 
--git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py new file mode 100644 index 000000000000..6ea83dcbf6a7 --- /dev/null +++ b/src/diffusers/hooks/_common.py @@ -0,0 +1,32 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..models.attention import FeedForward, LuminaFeedForward +from ..models.attention_processor import Attention, MochiAttention + + +_ATTENTION_CLASSES = (Attention, MochiAttention) +_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) + +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") + +_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple( + { + *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + } +) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py new file mode 100644 index 000000000000..9043ffc41838 --- /dev/null +++ b/src/diffusers/hooks/_helpers.py @@ -0,0 +1,271 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Any, Callable, Type + +from ..models.attention import BasicTransformerBlock +from ..models.attention_processor import AttnProcessor2_0 +from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock +from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock +from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ..models.transformers.transformer_hunyuan_video import ( + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenReplaceSingleTransformerBlock, + HunyuanVideoTokenReplaceTransformerBlock, + HunyuanVideoTransformerBlock, +) +from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock +from ..models.transformers.transformer_mochi import MochiTransformerBlock +from ..models.transformers.transformer_wan import WanTransformerBlock + + +@dataclass +class AttentionProcessorMetadata: + skip_processor_output_fn: Callable[[Any], Any] + + +@dataclass +class TransformerBlockMetadata: + skip_block_output_fn: Callable[[Any], Any] + return_hidden_states_index: int = None + return_encoder_hidden_states_index: int = None + + +class AttentionProcessorRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> AttentionProcessorMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +class TransformerBlockRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +def _register_attention_processors_metadata(): + # AttnProcessor2_0 + AttentionProcessorRegistry.register( + model_class=AttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0, + ), + ) + + # CogView4AttnProcessor + AttentionProcessorRegistry.register( + model_class=CogView4AttnProcessor, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor, + ), + ) + + +def _register_transformer_blocks_metadata(): + # BasicTransformerBlock + TransformerBlockRegistry.register( + model_class=BasicTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # CogVideoX + TransformerBlockRegistry.register( + model_class=CogVideoXBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # CogView4 + TransformerBlockRegistry.register( + model_class=CogView4TransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Flux + TransformerBlockRegistry.register( + model_class=FluxTransformerBlock, + metadata=TransformerBlockMetadata( + 
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + TransformerBlockRegistry.register( + model_class=FluxSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + # HunyuanVideo + TransformerBlockRegistry.register( + model_class=HunyuanVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # LTXVideo + TransformerBlockRegistry.register( + model_class=LTXVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # Mochi + TransformerBlockRegistry.register( + model_class=MochiTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Wan + TransformerBlockRegistry.register( + model_class=WanTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + +# fmt: off +def _skip_attention___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states +_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def 
_skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return encoder_hidden_states, hidden_states + + +_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +# fmt: on + + +_register_attention_processors_metadata() +_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py new file mode 100644 index 000000000000..b8471523902a --- /dev/null +++ b/src/diffusers/hooks/layer_skip.py @@ -0,0 +1,194 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Callable, List, Optional + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES +from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_LAYER_SKIP_HOOK = "layer_skip_hook" + + +@dataclass +class LayerSkipConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. + "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + skip_attention (`bool`, defaults to `True`): + Whether to skip attention blocks. + skip_ff (`bool`, defaults to `True`): + Whether to skip feed-forward blocks. + skip_attention_scores (`bool`, defaults to `False`): + Whether to skip attention score computation in the attention blocks. This is equivalent to using + `value` projections as the output of scaled dot product attention. 
+ """ + + indices: List[int] + fqn: str = "auto" + skip_attention: bool = True + skip_attention_scores: bool = False + skip_ff: bool = True + + +class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): + def __init__(self) -> None: + super().__init__() + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.scaled_dot_product_attention: + value = kwargs.get("value", None) + if value is None: + value = args[2] + return value + return func(*args, **kwargs) + + +class AttentionProcessorSkipHook(ModelHook): + def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False): + self.skip_processor_output_fn = skip_processor_output_fn + self.skip_attention_scores = skip_attention_scores + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.skip_attention_scores: + with AttentionScoreSkipFunctionMode(): + return self.fn_ref.original_forward(*args, **kwargs) + else: + return self.skip_processor_output_fn(module, *args, **kwargs) + + +class FeedForwardSkipHook(ModelHook): + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + output = kwargs.get("hidden_states", None) + if output is None: + output = kwargs.get("x", None) + if output is None and len(args) > 0: + output = args[0] + return output + + +class TransformerBlockSkipHook(ModelHook): + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + return self._metadata.skip_block_output_fn(module, *args, **kwargs) + + +def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: + r""" + Apply layer skipping to internal layers of a transformer. + Args: + module (`torch.nn.Module`): + The transformer model to which the layer skip hook should be applied. + config (`LayerSkipConfig`): + The configuration for the layer skip hook. + Example: + ```python + >>> import torch + >>> from diffusers import apply_layer_skip, CogVideoXTransformer3DModel, LayerSkipConfig + >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16) + >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks") + >>> apply_layer_skip(transformer, config) + ``` + """ + _apply_layer_skip_hook(module, config) + + +def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None: + name = name or _LAYER_SKIP_HOOK + + if config.skip_attention and config.skip_attention_scores: + raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.") + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList): + raise ValueError( + f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify " + f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
+ ) + if len(config.indices) == 0: + raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.") + + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + blocks_found = True + if config.skip_attention and config.skip_ff: + logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = TransformerBlockSkipHook() + registry.register_hook(hook, name) + elif config.skip_attention or config.skip_attention_scores: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: + logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'") + output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores) + registry.register_hook(hook, name) + elif config.skip_ff: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _FEEDFORWARD_CLASSES): + logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = FeedForwardSkipHook() + registry.register_hook(hook, name) + else: + raise ValueError( + "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True." + ) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." 
+ ) + + +def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: + for submodule_name, submodule in module.named_modules(): + if submodule_name == fqn: + return submodule + return None diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 8e7109308962..5f125605a271 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -19,7 +19,6 @@ import torch from collections import OrderedDict -from ...guider import CFGGuider from ...image_processor import VaeImageProcessor, PipelineImageInput from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel @@ -58,7 +57,7 @@ ) from ...schedulers import KarrasDiffusionSchedulers -from ...guider import Guiders, CFGGuider +from ...guiders import GuiderType, ClassifierFreeGuidance import numpy as np @@ -2068,7 +2067,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ] @@ -2082,12 +2081,9 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), - InputParam("guider_kwargs"), InputParam("num_images_per_prompt", default=1), ] @@ -2239,77 +2235,83 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - # Prepare conditional inputs using the guider - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - 
data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - # expand the latents if we are doing classifier free guidance - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - - # inpainting - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - # predict the noise residual - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance( - data.noise_pred, - timestep=t, - latents=data.latents, + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + + ( + latents, + prompt_embeds, + add_time_ids, + pooled_prompt_embeds, + mask, + masked_image_latents, + ip_adapter_embeds, + ) = pipeline.guider.prepare_inputs( + pipeline.unet, + data.latents, + (data.prompt_embeds, data.negative_prompt_embeds), + (data.add_time_ids, data.negative_add_time_ids), + (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds), + data.mask, + data.masked_image_latents, + (data.ip_adapter_embeds, data.negative_ip_adapter_embeds), ) - # compute the previous noisy sample x_t -> x_t-1 + + for batch_index, ( + latents_i, + prompt_embeds_i, + add_time_ids_i, + pooled_prompt_embeds_i, + mask_i, + masked_image_latents_i, + ip_adapter_embeds_i, + ) in enumerate(zip( + latents, + prompt_embeds, + add_time_ids, + pooled_prompt_embeds, + mask, + masked_image_latents, + ip_adapter_embeds + )): + latents_i = pipeline.scheduler.scale_model_input(latents_i, t) + + # Prepare for inpainting + if data.num_channels_unet == 9: + latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + + data.added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds_i, + "time_ids": add_time_ids_i, + } + if ip_adapter_embeds_i is not None: + data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i + + # predict the noise residual + data.noise_pred = pipeline.unet( + latents_i, + t, + encoder_hidden_states=prompt_embeds_i, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=data.added_cond_kwargs, + return_dict=False, + )[0] + data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) + + # Perform guidance + outputs = pipeline.guider.outputs + data.noise_pred = pipeline.guider(**outputs) + + # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -2328,7 +2330,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - pipeline.guider.reset_guider(pipeline) self.add_block_state(state, data) return pipeline, state @@ -2341,12 +2342,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), - ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()), ] @property @@ -2792,8 +2793,8 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetUnionModel), ComponentSpec("scheduler", KarrasDiffusionSchedulers), - ComponentSpec("guider", CFGGuider, obj=CFGGuider()), - ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), + ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), ] diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 3c8911773e39..06f9981f0138 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) +def unwrap_module(module): + """Unwraps a module if it was compiled with torch.compile()""" + return module._orig_mod if is_compiled_module(module) else module + + def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). 
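For reference, the guider protocol added in this first commit (set_state, then prepare_inputs, then one denoiser call per condition, then prepare_outputs, then combining via __call__) can also be driven outside the modular pipeline. The following is a minimal sketch assuming this patch is installed; ToyDenoiser, the tensor shapes, and the timestep values are illustrative stand-ins, while ClassifierFreeGuidance and its methods are the ones introduced above.

import torch

from diffusers import ClassifierFreeGuidance


class ToyDenoiser(torch.nn.Module):
    # Stand-in for a UNet/transformer denoiser; returns a fake noise prediction.
    def forward(self, hidden_states, timestep, encoder_hidden_states=None):
        return hidden_states * 0.1


denoiser = ToyDenoiser()
guider = ClassifierFreeGuidance(guidance_scale=7.5)

latents = torch.randn(1, 4, 64, 64)
prompt_embeds = torch.randn(1, 77, 2048)
negative_prompt_embeds = torch.randn(1, 77, 2048)

timesteps = torch.tensor([999, 499])
for i, t in enumerate(timesteps):
    guider.set_state(step=i, num_inference_steps=len(timesteps), timestep=t)
    # Plain tensors are replicated per condition; (cond, uncond) tuples become alternating batches.
    latents_list, prompt_embeds_list = guider.prepare_inputs(
        denoiser, latents, (prompt_embeds, negative_prompt_embeds)
    )
    for latents_i, prompt_embeds_i in zip(latents_list, prompt_embeds_list):
        noise_pred_i = denoiser(latents_i, t, encoder_hidden_states=prompt_embeds_i)
        guider.prepare_outputs(denoiser, noise_pred_i)  # stores pred_cond, then pred_uncond
    noise_pred = guider(**guider.outputs)  # combined CFG prediction for this step

SkipLayerGuidance follows the same protocol; per the code above, its prepare_inputs installs the layer-skip hooks on the denoiser when needed, and prepare_outputs removes them once all conditional batches for the step have been processed.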
From 9da8a9d1d557a290d691ee2a9eaebd107463130f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 14:04:04 +0200 Subject: [PATCH 02/24] support sdxl controlnet --- .../guiders/classifier_free_guidance.py | 6 + src/diffusers/guiders/guider_utils.py | 25 ++- src/diffusers/guiders/skip_layer_guidance.py | 8 + .../pipeline_stable_diffusion_xl_modular.py | 200 ++++++++---------- 4 files changed, 119 insertions(+), 120 deletions(-) diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 4048d70484c5..b3508307d478 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -92,6 +92,10 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = return pred + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + @property def num_conditions(self) -> int: num_conditions = 1 @@ -100,6 +104,8 @@ def num_conditions(self) -> int: return num_conditions def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index ecde7334b2b9..690afae89178 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -39,6 +39,7 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): self._timestep: torch.LongTensor = None self._preds: Dict[str, torch.Tensor] = {} self._num_outputs_prepared: int = 0 + self._enabled = True if not (0.0 <= start < 1.0): raise ValueError( @@ -54,6 +55,12 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): "`_input_predictions` must be a list of required prediction names for the guidance technique." 
) + def force_disable(self): + self._enabled = False + + def force_enable(self): + self._enabled = True + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: self._step = step self._num_inference_steps = num_inference_steps @@ -62,10 +69,10 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen self._num_outputs_prepared = 0 def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - raise NotImplementedError("GuidanceMixin::prepare_inputs must be implemented in subclasses.") + raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - raise NotImplementedError("GuidanceMixin::prepare_outputs must be implemented in subclasses.") + raise NotImplementedError("BaseGuidance::prepare_outputs must be implemented in subclasses.") def __call__(self, **kwargs) -> Any: if len(kwargs) != self.num_conditions: @@ -75,11 +82,19 @@ def __call__(self, **kwargs) -> Any: return self.forward(**kwargs) def forward(self, *args, **kwargs) -> Any: - raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.") + raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") + @property + def is_conditional(self) -> bool: + raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") + + @property + def is_unconditional(self) -> bool: + return not self.is_conditional + @property def num_conditions(self) -> int: - raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.") + raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") @property def outputs(self) -> Dict[str, torch.Tensor]: @@ -114,7 +129,7 @@ def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *arg """ Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly prepared based on required number of conditions. This function is used in the `prepare_inputs` method of the - `GuidanceMixin` class. + `BaseGuidance` class. Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements: - The first element is the conditional input. 
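For reference, a small sketch (not part of the patch) of the splitting convention that `_default_prepare_inputs` implements when two conditions are active: a `(cond, uncond)` tuple contributes one entry per condition, while a bare tensor is simply reused for every condition. The names below are placeholders.

```python
latents_list, prompt_embeds_list = guider.prepare_inputs(
    unet,
    latents,                                  # bare tensor: reused for both passes
    (prompt_embeds, negative_prompt_embeds),  # tuple: split into conditional / unconditional
)
# With num_conditions == 2 this yields, conditional entry first ("pred_cond", then "pred_uncond"):
# latents_list       -> [latents, latents]
# prompt_embeds_list -> [prompt_embeds, negative_prompt_embeds]
```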
diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index e20a700fee6a..92ae7f8518d8 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -189,6 +189,10 @@ def forward( pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2 @property def num_conditions(self) -> int: @@ -200,6 +204,8 @@ def num_conditions(self) -> int: return num_conditions def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step @@ -211,6 +217,8 @@ def _is_cfg_enabled(self) -> bool: return is_within_range and not is_close def _is_slg_enabled(self) -> bool: + if not self._enabled: + return False skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step < self._step < skip_stop_step diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5f125605a271..5df57c6c1605 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2263,21 +2263,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) for batch_index, ( - latents_i, - prompt_embeds_i, - add_time_ids_i, - pooled_prompt_embeds_i, - mask_i, - masked_image_latents_i, - ip_adapter_embeds_i, + latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i, ) in enumerate(zip( - latents, - prompt_embeds, - add_time_ids, - pooled_prompt_embeds, - mask, - masked_image_latents, - ip_adapter_embeds + latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): latents_i = pipeline.scheduler.scale_model_input(latents_i, t) @@ -2285,6 +2273,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if data.num_channels_unet == 9: latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + # Prepare additional conditionings data.added_cond_kwargs = { "text_embeds": pooled_prompt_embeds_i, "time_ids": add_time_ids_i, @@ -2292,7 +2281,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if ip_adapter_embeds_i is not None: data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i - # predict the noise residual + # Predict the noise residual data.noise_pred = pipeline.unet( latents_i, t, @@ -2347,7 +2336,6 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), - ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()), ] @property @@ -2363,8 +2351,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - 
InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), @@ -2515,8 +2501,8 @@ def prepare_control_image( image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] - if image_batch_size == 1: repeat_by = batch_size else: @@ -2524,9 +2510,7 @@ def prepare_control_image( repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components @@ -2557,9 +2541,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] data.height = data.height * pipeline.vae_scale_factor data.width = data.width * pipeline.vae_scale_factor @@ -2642,59 +2624,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # (2) Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds + if data.disable_guidance: + pipeline.guider.force_disable() # (3) Prepare conditional inputs for controlnet using the guider data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False - data.controlnet_guider_kwargs = data.guider_kwargs or {} - data.controlnet_guider_kwargs = { - **data.controlnet_guider_kwargs, - "disable_guidance": data.controlnet_disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * 
data.num_images_per_prompt, - } - pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) - data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - data.controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - } - data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -2703,11 +2638,26 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (5) Denoise loop with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - # prepare latents for unet using the guider - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - # prepare latents for controlnet using the guider - data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) + ( + latents, + prompt_embeds, + add_time_ids, + pooled_prompt_embeds, + mask, + masked_image_latents, + ip_adapter_embeds, + ) = pipeline.guider.prepare_inputs( + pipeline.unet, + data.latents, + (data.prompt_embeds, data.negative_prompt_embeds), + (data.add_time_ids, data.negative_add_time_ids), + (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds), + data.mask, + data.masked_image_latents, + (data.ip_adapter_embeds, data.negative_ip_adapter_embeds), + ) if isinstance(data.controlnet_keep[i], list): data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] @@ -2717,51 +2667,74 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.controlnet_cond_scale = data.controlnet_cond_scale[0] data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(data.control_model_input, t), - t, - encoder_hidden_states=data.controlnet_prompt_embeds, - controlnet_cond=data.control_image, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, - return_dict=False, - ) + for batch_index, ( + latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i + ) in enumerate(zip( + latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds + )): + latents_i = pipeline.scheduler.scale_model_input(latents_i, t) + + # Prepare for inpainting + if data.num_channels_unet == 9: + latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + + # Prepare additional conditionings + data.added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds_i, + "time_ids": add_time_ids_i, + } + if ip_adapter_embeds_i is not None: + data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i + + # Prepare controlnet additional conditionings + data.controlnet_added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds_i, + "time_ids": add_time_ids_i, + } - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - 
data.down_block_res_samples = pipeline.guider.prepare_input( - data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] - ) - data.mid_block_res_sample = pipeline.guider.prepare_input( - data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) - ) + if pipeline.guider.is_conditional or not data.guess_mode: + data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( + latents_i, + t, + encoder_hidden_states=prompt_embeds_i, + controlnet_cond=data.control_image, + conditioning_scale=data.cond_scale, + guess_mode=data.guess_mode, + added_cond_kwargs=data.controlnet_added_cond_kwargs, + return_dict=False, + ) + elif pipeline.guider.is_unconditional and data.guess_mode: + data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] + data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) + if data.num_channels_unet == 9: + latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) - # compute the previous noisy sample x_t -> x_t-1 + data.noise_pred = pipeline.unet( + latents_i, + t, + encoder_hidden_states=prompt_embeds_i, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=data.added_cond_kwargs, + down_block_additional_residuals=data.down_block_res_samples, + mid_block_additional_residual=data.mid_block_res_sample, + return_dict=False, + )[0] + data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) + + # Perform guidance + outputs = pipeline.guider.outputs + data.noise_pred = pipeline.guider(**outputs) + + # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 data.latents = data.latents.to(data.latents_dtype) - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: data.init_latents_proper = data.image_latents @@ -2775,9 +2748,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - pipeline.controlnet_guider.reset_guider(pipeline) self.add_block_state(state, data) From b81bd78bf9ad422720b52ec86907ea0a945181e3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 15:39:16 +0200 Subject: [PATCH 03/24] support controlnet union --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 1 + .../guiders/adaptive_projected_guidance.py | 174 +++++++++++++ .../guiders/classifier_free_guidance.py | 11 +- src/diffusers/guiders/skip_layer_guidance.py | 22 +- .../pipeline_stable_diffusion_xl_modular.py | 246 ++++++++---------- 6 files changed, 312 insertions(+), 144 deletions(-) create mode 100644 src/diffusers/guiders/adaptive_projected_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d8e274cb93b6..d3a61df3baff 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -132,6 +132,7 @@ else: _import_structure["guiders"].extend( [ + "AdaptiveProjectedGuidance", "ClassifierFreeGuidance", "SkipLayerGuidance", ] @@ -721,6 +722,7 @@ from .utils.dummy_pt_objects import * # noqa F403 else: from .guiders import ( + AdaptiveProjectedGuidance, ClassifierFreeGuidance, SkipLayerGuidance, ) diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index adef65277b6d..e3c6494de090 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -18,6 +18,7 @@ if is_torch_available(): + from .adaptive_projected_guidance import AdaptiveProjectedGuidance from .classifier_free_guidance import ClassifierFreeGuidance from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py new file mode 100644 index 000000000000..ab5175745b07 --- /dev/null +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -0,0 +1,174 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union, Tuple, List + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class AdaptiveProjectedGuidance(BaseGuidance): + """ + Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. 
Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + adaptive_projected_guidance_momentum (`float`, defaults to `None`): + The momentum parameter for the adaptive projected guidance. Disabled if set to `None`. + adaptive_projected_guidance_rescale (`float`, defaults to `15.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + adaptive_projected_guidance_momentum: Optional[float] = None, + adaptive_projected_guidance_rescale: float = 15.0, + eta: float = 1.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale + self.eta = eta + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.eta, + self.adaptive_projected_guidance_rescale, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + + @property + def num_conditions(self) 
-> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + +def normalized_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer: Optional[MomentumBuffer] = None, + eta: float = 1.0, + norm_threshold: float = 0.0, + use_original_formulation: bool = False, +): + diff = pred_cond - pred_uncond + dim = [-i for i in range(1, len(diff.shape))] + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=dim) + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) + normalized_update = diff_orthogonal + eta * diff_parallel + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + (guidance_scale - 1) * normalized_update + return pred diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index b3508307d478..3deacdfb2863 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -106,12 +106,17 @@ def num_conditions(self) -> int: def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - skip_start_step = int(self._start * self._num_inference_steps) - skip_stop_step = int(self._stop * self._num_inference_steps) - is_within_range = skip_start_step <= self._step < skip_stop_step + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) + return is_within_range and not is_close diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 92ae7f8518d8..4abb93272e9a 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -206,21 +206,31 @@ def num_conditions(self) -> int: def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - skip_start_step = int(self._start * self._num_inference_steps) - 
skip_stop_step = int(self._stop * self._num_inference_steps) - is_within_range = skip_start_step <= self._step < skip_stop_step + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) + return is_within_range and not is_close def _is_slg_enabled(self) -> bool: if not self._enabled: return False - skip_start_step = int(self._start * self._num_inference_steps) - skip_stop_step = int(self._stop * self._num_inference_steps) - is_within_range = skip_start_step < self._step < skip_stop_step + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + return is_within_range and not is_zero diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5df57c6c1605..24d7e333c4fa 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -194,11 +194,7 @@ def inputs(self) -> List[InputParam]: PipelineImageInput, required=True, description="The image(s) to be used as ip adapter" - ), - InputParam( - "guidance_scale", - default=5.0, - ), + ) ] @@ -236,11 +232,10 @@ def encode_image(self, components, image, device, num_images_per_prompt, output_ # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt ): image_embeds = [] - if do_classifier_free_guidance: - negative_image_embeds = [] + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -259,21 +254,18 @@ def prepare_ip_adapter_image_embeds( ) image_embeds.append(single_image_embeds[None, :]) - if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None, :]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * 
num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) @@ -323,6 +315,7 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), ] @property @@ -337,7 +330,6 @@ def inputs(self) -> List[InputParam]: InputParam("negative_prompt"), InputParam("negative_prompt_2"), InputParam("cross_attention_kwargs"), - InputParam("guidance_scale",default=5.0), InputParam("clip_skip"), ] @@ -601,10 +593,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) self.check_inputs(pipeline, data) - data.do_classifier_free_guidance = data.guidance_scale > 1.0 + data.do_classifier_free_guidance = pipeline.guider.num_conditions > 1 data.device = pipeline._execution_device - # Encode input prompt data.text_encoder_lora_scale = ( data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None @@ -1750,7 +1741,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", required=True), InputParam("aesthetic_score", default=6.0), InputParam("negative_aesthetic_score", default=2.0), ] @@ -1897,7 +1887,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin and pipeline.unet is not None and pipeline.unet.config.time_cond_proj_dim is not None ): - data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) data.timestep_cond = self.get_guidance_scale_embedding( data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim ).to(device=data.device, dtype=data.latents.dtype) @@ -1925,7 +1916,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), ] @property @@ -2051,7 +2041,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin and pipeline.unet is not None and pipeline.unet.config.time_cond_proj_dim is not None ): - data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! 
+ data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) data.timestep_cond = self.get_guidance_scale_embedding( data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim ).to(device=data.device, dtype=data.latents.dtype) @@ -2234,6 +2225,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False + if data.disable_guidance: + pipeline.guider.force_disable() + else: + pipeline.guider.force_enable() # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -2354,7 +2349,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), - InputParam("guider_kwargs"), ] @property @@ -2627,9 +2621,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: pipeline.guider.force_disable() - - # (3) Prepare conditional inputs for controlnet using the guider - data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False + else: + pipeline.guider.force_enable() # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -2764,7 +2757,6 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("controlnet", ControlNetUnionModel), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), ] @@ -2781,12 +2773,9 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), - InputParam("guider_kwargs") ] @property @@ -3029,7 +3018,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: crops_coords=data.crops_coords, ) data.height, data.width = data.control_image[idx].shape[-2:] - # (1.6) # controlnet_keep @@ -3043,80 +3031,48 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (2) Prepare conditional inputs for unet using the guider # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - data.prompt_embeds = pipeline.guider.prepare_input( - 
data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds + if data.disable_guidance: + pipeline.guider.force_disable() + else: + pipeline.guider.force_enable() # (3) Prepare conditional inputs for controlnet using the guider - data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False - data.controlnet_guider_kwargs = data.guider_kwargs or {} - data.controlnet_guider_kwargs = { - **data.controlnet_guider_kwargs, - "disable_guidance": data.controlnet_disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) - data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - data.controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - } - for idx, _ in enumerate(data.control_image): - data.control_image[idx] = pipeline.controlnet_guider.prepare_input(data.control_image[idx], data.control_image[idx]) + # data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) + # data.controlnet_added_cond_kwargs = { + # "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), + # "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), + # } - data.control_type = ( - data.control_type.reshape(1, -1) - .to(data.device, dtype=data.prompt_embeds.dtype) - ) + data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) - data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type) # (4) Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - # prepare latents for unet using the guider - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - # prepare latents for controlnet using the guider - data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) + ( + latents, + prompt_embeds, + add_time_ids, + pooled_prompt_embeds, + mask, + masked_image_latents, + ip_adapter_embeds, + ) = pipeline.guider.prepare_inputs( + pipeline.unet, + data.latents, + (data.prompt_embeds, data.negative_prompt_embeds), + (data.add_time_ids, data.negative_add_time_ids), + (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds), + data.mask, + data.masked_image_latents, + (data.ip_adapter_embeds, data.negative_ip_adapter_embeds), + ) if isinstance(data.controlnet_keep[i], list): data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] @@ -3126,48 +3082,72 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.controlnet_cond_scale = data.controlnet_cond_scale[0] data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(data.control_model_input, t), - t, - encoder_hidden_states=data.controlnet_prompt_embeds, - controlnet_cond=data.control_image, - control_type=data.control_type, - control_type_idx=data.control_mode, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, - return_dict=False, - ) + for batch_index, ( + latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i + ) in enumerate(zip( + latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds + )): + latents_i = pipeline.scheduler.scale_model_input(latents_i, t) + + # Prepare for inpainting + if data.num_channels_unet == 9: + latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + + # Prepare additional conditionings + data.added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds_i, + "time_ids": add_time_ids_i, + } + if ip_adapter_embeds_i is not None: + data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i + + # Prepare controlnet additional conditionings + data.controlnet_added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds_i, + "time_ids": add_time_ids_i, + } + + if pipeline.guider.is_conditional or not data.guess_mode: + data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( + latents_i, + t, + encoder_hidden_states=prompt_embeds_i, + controlnet_cond=data.control_image, + control_type=data.control_type, + control_type_idx=data.control_mode, + conditioning_scale=data.cond_scale, + guess_mode=data.guess_mode, + added_cond_kwargs=data.controlnet_added_cond_kwargs, + return_dict=False, + ) + elif pipeline.guider.is_unconditional and data.guess_mode: + data.down_block_res_samples = [torch.zeros_like(d) for d in 
data.down_block_res_samples] + data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - data.down_block_res_samples = pipeline.guider.prepare_input( - data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] - ) - data.mid_block_res_sample = pipeline.guider.prepare_input( - data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) - ) + if data.num_channels_unet == 9: + latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + + data.noise_pred = pipeline.unet( + latents_i, + t, + encoder_hidden_states=prompt_embeds_i, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=data.added_cond_kwargs, + down_block_additional_residuals=data.down_block_res_samples, + mid_block_additional_residual=data.mid_block_res_sample, + return_dict=False, + )[0] + data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) + + # Perform guidance + outputs = pipeline.guider.outputs + data.noise_pred = pipeline.guider(**outputs) - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) - # compute the previous noisy sample x_t -> x_t-1 + # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -3180,14 +3160,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.init_latents_proper = pipeline.scheduler.add_noise( data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) ) - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - pipeline.controlnet_guider.reset_guider(pipeline) self.add_block_state(state, data) From 31593e2c3336b5eed36ae4349214dc612585ebf2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 15:56:36 +0200 Subject: [PATCH 04/24] update --- .../pipeline_stable_diffusion_xl_modular.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 24d7e333c4fa..0cb4294e12b9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -184,6 +184,7 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("image_encoder", CLIPVisionModelWithProjection), ComponentSpec("feature_extractor", CLIPImageProcessor), ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("guider", GuiderType), ] @property @@ -276,7 +277,7 @@ def prepare_ip_adapter_image_embeds( def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) - data.do_classifier_free_guidance = data.guidance_scale > 1.0 + data.do_classifier_free_guidance = pipeline.guider.num_conditions > 1 data.device = pipeline._execution_device data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( @@ -315,7 +316,7 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", CLIPTokenizer), - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), + ComponentSpec("guider", GuiderType), ] @property @@ -3490,6 +3491,11 @@ def description(self): "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ "- for text-to-image generation, all you need to provide is `prompt`" +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. 
+ # block mapping TEXT2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), @@ -3611,7 +3617,6 @@ def num_channels_latents(self): "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), - "guidance_scale": InputParam("guidance_scale", type_hint=float, default=5.0, description="Classifier-Free Diffusion Guidance scale"), "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), @@ -3636,7 +3641,6 @@ def num_channels_latents(self): "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), - "guidance_rescale": InputParam("guidance_rescale", type_hint=float, default=0.0, description="Guidance rescale factor to fix overexposure"), "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), "guider_kwargs": InputParam("guider_kwargs", type_hint=Optional[Dict[str, Any]], description="Kwargs dictionary passed to the Guider"), "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), @@ -3704,4 +3708,4 @@ def num_channels_latents(self): SDXL_OUTPUTS_SCHEMA = { "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") -} \ No newline at end of file +} From 625530295de7bb2e14c169a05c4358772fea907f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 15:57:17 +0200 Subject: [PATCH 05/24] update --- .../pipeline_stable_diffusion_xl_modular.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 0cb4294e12b9..a02e869e8b32 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -233,10 +233,11 @@ def encode_image(self, components, image, device, num_images_per_prompt, output_ # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + 
self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] - negative_image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -255,18 +256,21 @@ def prepare_ip_adapter_image_embeds( ) image_embeds.append(single_image_embeds[None, :]) - negative_image_embeds.append(single_negative_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) From 2238f55f40b3a7ddffab0fb176aa8d1bf1dec22f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 16:43:16 +0200 Subject: [PATCH 06/24] cfg zero* --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 1 + .../guiders/classifier_free_guidance.py | 5 + .../classifier_free_zero_star_guidance.py | 143 ++++++++++++++++++ src/diffusers/guiders/skip_layer_guidance.py | 6 + 5 files changed, 157 insertions(+) create mode 100644 src/diffusers/guiders/classifier_free_zero_star_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d3a61df3baff..b2e24614b97a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -134,6 +134,7 @@ [ "AdaptiveProjectedGuidance", "ClassifierFreeGuidance", + "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", ] ) @@ -724,6 +725,7 @@ from .guiders import ( AdaptiveProjectedGuidance, ClassifierFreeGuidance, + ClassifierFreeZeroStarGuidance, SkipLayerGuidance, ) from .hooks import ( diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index e3c6494de090..ac3837a6b4d7 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -20,6 +20,7 @@ if is_torch_available(): from .adaptive_projected_guidance import AdaptiveProjectedGuidance from .classifier_free_guidance import ClassifierFreeGuidance + from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance GuiderType = Union[ClassifierFreeGuidance, SkipLayerGuidance] diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 3deacdfb2863..6978080b7152 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ 
b/src/diffusers/guiders/classifier_free_guidance.py @@ -23,20 +23,25 @@ class ClassifierFreeGuidance(BaseGuidance): """ Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper proposes scaling and shifting the conditional distribution based on the difference between conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + The intution behind the original formulation can be thought of as moving the conditional distribution estimates further away from the unconditional distribution estimates, while the diffusers-native implementation can be thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py new file mode 100644 index 000000000000..04c504f8f2d6 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -0,0 +1,143 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union, Tuple, List + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class ClassifierFreeZeroStarGuidance(BaseGuidance): + """ + Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886 + + This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free + guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion + process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the + quality of generated images. + + The authors of the paper suggest setting zero initialization in the first 4% of the inference steps. 
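+    Concretely, in this implementation the unconditional prediction is first rescaled by a per-sample projection coefficient computed by `cfg_zero_star_scale` (the paper's optimized scale, s* = <pred_cond, pred_uncond> / ||pred_uncond||^2), and the usual classifier-free guidance combination is then applied to the rescaled prediction.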
+ + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + zero_init_steps (`int`, defaults to `1`): + The number of inference steps for which the noise predictions are zeroed out (see Section 4.2). + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + zero_init_steps: int = 1, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.zero_init_steps = zero_init_steps + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if self._step < self.zero_init_steps: + pred = torch.zeros_like(pred_cond) + elif not self._is_cfg_enabled(): + pred = pred_cond + else: + pred_cond_flat = pred_cond.flatten(1) + pred_uncond_flat = pred_uncond.flatten(1) + alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat) + alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1)) + pred_uncond = pred_uncond * alpha + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not 
None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + cond_dtype = cond.dtype + cond = cond.float() + uncond = uncond.float() + dot_product = torch.sum(cond * uncond, dim=1, keepdim=True) + squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + scale = dot_product / squared_norm + return scale.to(dtype=cond_dtype) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 4abb93272e9a..bac851c0dc28 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -26,20 +26,26 @@ class SkipLayerGuidance(BaseGuidance): """ Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664 + SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional batch of data, apart from the conditional and unconditional batches already used in CFG ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions based on the difference between conditional without skipping and conditional with skipping predictions. + The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse version of the model for the conditional prediction). + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving generation quality in video diffusion models. + Additional reading: - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. 
Higher values result in stronger conditioning on the text From 52b9b61f62087f0120cd6045bbac3f0e4cdc55e2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 17:34:10 +0200 Subject: [PATCH 07/24] use unwrap_module for torch compiled modules --- .../pipeline_stable_diffusion_xl_modular.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index a02e869e8b32..0b36633efc4e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -30,7 +30,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import randn_tensor, unwrap_module from ..controlnet.multicontrolnet import MultiControlNetModel from ..modular_pipeline import ( AutoPipelineBlocks, @@ -2545,7 +2545,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.height = data.height * pipeline.vae_scale_factor data.width = data.width * pipeline.vae_scale_factor - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + controlnet = unwrap_module(pipeline.controlnet) # (1.1) # control_guidance_start/control_guidance_end (align format) @@ -2973,7 +2973,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.height = data.height * pipeline.vae_scale_factor data.width = data.width * pipeline.vae_scale_factor - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + controlnet = unwrap_module(pipeline.controlnet) # (1.1) # control guidance From d8d2ea37293727d2201ddc09b03a307bbb07c421 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 17:35:30 +0200 Subject: [PATCH 08/24] remove guider kwargs --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 0b36633efc4e..c67ec2f4decb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -3646,7 +3646,6 @@ def num_channels_latents(self): "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), - "guider_kwargs": InputParam("guider_kwargs", type_hint=Optional[Dict[str, Any]], description="Kwargs dictionary passed to the Guider"), "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), From 
b31904b853d0d271c16c640a712dc3ac43f15a05 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 17:39:06 +0200 Subject: [PATCH 09/24] remove commented code --- .../pipeline_stable_diffusion_xl_modular.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index c67ec2f4decb..68d9d913bd3a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -3041,13 +3041,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: else: pipeline.guider.force_enable() - # (3) Prepare conditional inputs for controlnet using the guider - # data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - # data.controlnet_added_cond_kwargs = { - # "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - # "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - # } - data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) From 57c7e15a919d5d01fc236058ba1dc8fd2d16cccd Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 17:45:56 +0200 Subject: [PATCH 10/24] remove old guider --- src/diffusers/guider.py | 748 ------------------------------ src/diffusers/guiders/__init__.py | 2 +- 2 files changed, 1 insertion(+), 749 deletions(-) delete mode 100644 src/diffusers/guider.py diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py deleted file mode 100644 index b42dca64d651..000000000000 --- a/src/diffusers/guider.py +++ /dev/null @@ -1,748 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from .models.attention_processor import ( - Attention, - AttentionProcessor, - PAGCFGIdentitySelfAttnProcessor2_0, - PAGIdentitySelfAttnProcessor2_0, -) -from .utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - r""" - Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on - Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). - - Args: - noise_cfg (`torch.Tensor`): - The predicted noise tensor for the guided diffusion process. 
- noise_pred_text (`torch.Tensor`): - The predicted noise tensor for the text-guided diffusion process. - guidance_rescale (`float`, *optional*, defaults to 0.0): - A rescale factor applied to the noise predictions. - - Returns: - noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -class CFGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size :] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. 
- """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 2 - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_classifier_free_guidance: - return model_output - - noise_pred_uncond, noise_pred_text = model_output.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred - - -class PAGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - def __init__( - self, - pag_applied_layers: Union[str, List[str]], - pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( - PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0(), - ), - ): - r""" - Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. 
- - Args: - pag_applied_layers (`str` or `List[str]`): - One or more strings identifying the layer names, or a simple regex for matching multiple layers, where - PAG is to be applied. A few ways of expected usage are as follows: - - Single layers specified as - "blocks.{layer_index}" - - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] - - Multiple layers as a block name - "mid" - - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" - pag_attn_processors: - (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention - processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second - attention processor is for PAG with CFG disabled (unconditional only). - """ - - if not isinstance(pag_applied_layers, list): - pag_applied_layers = [pag_applied_layers] - if pag_attn_processors is not None: - if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: - raise ValueError("Expected a tuple of two attention processors") - - for i in range(len(pag_applied_layers)): - if not isinstance(pag_applied_layers[i], str): - raise ValueError( - f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" - ) - - self.pag_applied_layers = pag_applied_layers - self._pag_attn_processors = pag_attn_processors - - def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance): - r""" - Set the attention processor for the PAG layers. - """ - pag_attn_processors = self._pag_attn_processors - pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] - - def is_self_attn(module: nn.Module) -> bool: - r""" - Check if the module is self-attention module based on its name. 
- """ - return isinstance(module, Attention) and not module.is_cross_attention - - def is_fake_integral_match(layer_id, name): - layer_id = layer_id.split(".")[-1] - name = name.split(".")[-1] - return layer_id.isnumeric() and name.isnumeric() and layer_id == name - - for layer_id in pag_applied_layers: - # for each PAG layer input, we find corresponding self-attention layers in the unet model - target_modules = [] - - for name, module in model.named_modules(): - # Identify the following simple cases: - # (1) Self Attention layer existing - # (2) Whether the module name matches pag layer id even partially - # (3) Make sure it's not a fake integral match if the layer_id ends with a number - # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" - if ( - is_self_attn(module) - and re.search(layer_id, name) is not None - and not is_fake_integral_match(layer_id, name) - ): - logger.debug(f"Applying PAG to layer: {name}") - target_modules.append(module) - - if len(target_modules) == 0: - raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") - - for module in target_modules: - module.processor = pag_attn_proc - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and not self._disable_guidance - - @property - def do_perturbed_attention_guidance(self): - return self._pag_scale > 0 and not self._disable_guidance - - @property - def do_pag_adaptive_scaling(self): - return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def batch_size(self): - return self._batch_size - - @property - def pag_scale(self): - return self._pag_scale - - @property - def pag_adaptive_scale(self): - return self._pag_adaptive_scale - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - pag_scale = guider_kwargs.get("pag_scale", 3.0) - pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0) - - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is a required argument for PAGGuider") - - guidance_scale = guider_kwargs.get("guidance_scale", None) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - disable_guidance = guider_kwargs.get("disable_guidance", False) - - if guidance_scale is None: - raise ValueError("guidance_scale is a required argument for PAGGuider") - - self._pag_scale = pag_scale - self._pag_adaptive_scale = pag_adaptive_scale - self._guidance_scale = guidance_scale - self._disable_guidance = disable_guidance - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None: - pipeline.original_attn_proc = pipeline.unet.attn_processors - self._set_pag_attn_processor( - model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, - pag_applied_layers=self.pag_applied_layers, - do_classifier_free_guidance=self.do_classifier_free_guidance, - ) - - def reset_guider(self, pipeline): - if ( - self.do_perturbed_attention_guidance - and hasattr(pipeline, "original_attn_proc") - and pipeline.original_attn_proc is not None - ): - pipeline.unet.set_attn_processor(pipeline.original_attn_proc) - pipeline.original_attn_proc = None - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, 
cond_input): - pass - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Perturbed Attention Guidance (PAG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 3 - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 3: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size : self.batch_size * 2] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The negative conditional input. It can be a single tensor or a list of tensors. It must have the same - length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. 
- - if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_perturbed_attention_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - - cond = torch.cat([cond] * 2, dim=0) - if self.do_classifier_free_guidance: - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - else: - prepared_input.append(cond) - - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_perturbed_attention_guidance: - return cond_input - - cond_input = torch.cat([cond_input] * 2, dim=0) - if self.do_classifier_free_guidance: - return torch.cat([negative_cond_input, cond_input], dim=0) - else: - return cond_input - - else: - raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_perturbed_attention_guidance: - return model_output - - if self.do_pag_adaptive_scaling: - pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0) - else: - pag_scale = self._pag_scale - - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3) - noise_pred = ( - noise_pred_uncond - + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - + pag_scale * (noise_pred_text - noise_pred_perturb) - ) - else: - noise_pred_text, noise_pred_perturb = model_output.chunk(2) - noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - - return noise_pred - - -class MomentumBuffer: - def __init__(self, momentum: float): - self.momentum = momentum - self.running_average = 0 - - def update(self, update_value: torch.Tensor): - new_average = self.momentum * self.running_average - self.running_average = update_value + new_average - - -class APGGuider: - """ - This class is used to guide the pipeline with APG (Adaptive Projected Guidance). 
- """ - - def normalized_guidance( - self, - pred_cond: torch.Tensor, - pred_uncond: torch.Tensor, - guidance_scale: float, - momentum_buffer: MomentumBuffer = None, - norm_threshold: float = 0.0, - eta: float = 1.0, - ): - """ - Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion - Models](https://arxiv.org/pdf/2410.02416) - """ - diff = pred_cond - pred_uncond - if momentum_buffer is not None: - momentum_buffer.update(diff) - diff = momentum_buffer.running_average - if norm_threshold > 0: - ones = torch.ones_like(diff) - diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True) - scale_factor = torch.minimum(ones, norm_threshold / diff_norm) - diff = diff * scale_factor - v0, v1 = diff.double(), pred_cond.double() - v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) - v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 - v0_orthogonal = v0 - v0_parallel - diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype) - normalized_update = diff_orthogonal + eta * diff_parallel - pred_guided = pred_cond + (guidance_scale - 1) * normalized_update - return pred_guided - - @property - def adaptive_projected_guidance_momentum(self): - return self._adaptive_projected_guidance_momentum - - @property - def adaptive_projected_guidance_rescale_factor(self): - return self._adaptive_projected_guidance_rescale_factor - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None) - adaptive_projected_guidance_rescale_factor = guider_kwargs.get( - "adaptive_projected_guidance_rescale_factor", 15.0 - ) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum - self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - if adaptive_projected_guidance_momentum is not None: - self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum) - else: - self.momentum_buffer = None - self.scheduler = pipeline.scheduler - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). 
- It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size :] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 2 - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. 
- if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_classifier_free_guidance: - return model_output - - if latents is None: - raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).") - - sigma = self.scheduler.sigmas[self.scheduler.step_index] - noise_pred = latents - sigma * model_output - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = self.normalized_guidance( - noise_pred_text, - noise_pred_uncond, - self.guidance_scale, - self.momentum_buffer, - self.adaptive_projected_guidance_rescale_factor, - ) - noise_pred = (latents - noise_pred) / sigma - - if self.guidance_rescale > 0.0: - # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred - - -Guiders = Union[CFGGuider, PAGGuider, APGGuider] \ No newline at end of file diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index ac3837a6b4d7..52b9176f5fb7 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -23,4 +23,4 @@ from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance - GuiderType = Union[ClassifierFreeGuidance, SkipLayerGuidance] + GuiderType = Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance] From 8d31c699a5bed011afc2ac6569eff3892fe885ae Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 08:26:03 +0200 Subject: [PATCH 11/24] fix slg bug --- .../guiders/adaptive_projected_guidance.py | 6 ++-- src/diffusers/guiders/guider_utils.py | 7 ++++ src/diffusers/guiders/skip_layer_guidance.py | 35 ++++++++++++++----- src/diffusers/hooks/layer_skip.py | 13 ++++--- .../pipeline_stable_diffusion_xl_modular.py | 3 ++ 5 files changed, 48 insertions(+), 16 deletions(-) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index ab5175745b07..45bd196860f4 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -86,7 +86,7 @@ def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None - if not self._is_cfg_enabled(): + if not self._is_apg_enabled(): pred = pred_cond else: pred = normalized_guidance( @@ -111,11 +111,11 @@ def is_conditional(self) -> bool: @property def num_conditions(self) -> int: num_conditions = 1 - if self._is_cfg_enabled(): + if self._is_apg_enabled(): num_conditions += 1 return num_conditions - def _is_cfg_enabled(self) -> bool: + def _is_apg_enabled(self) -> bool: if not self._enabled: return False diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 690afae89178..60859bf390f6 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -68,6 +68,13 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen self._preds = {} self._num_outputs_prepared = 0 + def prepare_models(self, denoiser: torch.nn.Module) -> None: + """ + Prepares the models for the guidance technique on a given batch of data. This method should be overridden in + subclasses to implement specific model preparation logic. + """ + pass + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index bac851c0dc28..3fbfd771eff9 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -54,6 +54,10 @@ class SkipLayerGuidance(BaseGuidance): skip_layer_guidance_scale (`float`, defaults to `2.8`): The scale parameter for skip layer guidance. 
Anatomy and structure coherence may improve with higher values, but it may also lead to overexposure and saturation. + skip_layer_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which skip layer guidance starts. + skip_layer_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which skip layer guidance stops. skip_layer_guidance_layers (`int` or `List[int]`, *optional*): The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion @@ -81,20 +85,33 @@ def __init__( self, guidance_scale: float = 7.5, skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, guidance_rescale: float = 0.0, use_original_formulation: bool = False, - start: float = 0.01, - stop: float = 0.2, + start: float = 0.0, + stop: float = 1.0, ): super().__init__(start, stop) self.guidance_scale = guidance_scale self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.skip_layer_guidance_start = skip_layer_guidance_start + self.skip_layer_guidance_stop = skip_layer_guidance_stop self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation + if not (0.0 <= skip_layer_guidance_start < 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}." + ) + if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}." + ) + if skip_layer_guidance_layers is None and skip_layer_config is None: raise ValueError( "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." 
@@ -122,11 +139,12 @@ def __init__( self.skip_layer_config = skip_layer_config self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - if self._num_outputs_prepared == 0 and self._is_slg_enabled(): + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_slg_enabled() and self.is_conditional and self._num_outputs_prepared > 0: for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): _apply_layer_skip_hook(denoiser, config, name=name) - + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: num_conditions = self.num_conditions list_of_inputs = [] for arg in args: @@ -161,7 +179,8 @@ def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None key = "pred_cond_skip" self._preds[key] = pred - if self._num_outputs_prepared == self.num_conditions: + if key == "pred_cond_skip": + # If we are in SLG mode, we need to remove the hooks after inference registry = HookRegistry.check_if_exists_or_initialize(denoiser) # Remove the hooks after inference for hook_name in self._skip_layer_hook_names: @@ -233,8 +252,8 @@ def _is_slg_enabled(self) -> bool: is_within_range = True if self._num_inference_steps is not None: - skip_start_step = int(self._start * self._num_inference_steps) - skip_stop_step = int(self._stop * self._num_inference_steps) + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) is_within_range = skip_start_step < self._step < skip_stop_step is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index b8471523902a..e28322ac4865 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -80,14 +80,17 @@ def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bo def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.skip_attention_scores: + print("Skipping attention scores") with AttentionScoreSkipFunctionMode(): return self.fn_ref.original_forward(*args, **kwargs) else: + print("Skipping attention processor output") return self.skip_processor_output_fn(module, *args, **kwargs) class FeedForwardSkipHook(ModelHook): def new_forward(self, module: torch.nn.Module, *args, **kwargs): + print("Skipping feed-forward block") output = kwargs.get("hidden_states", None) if output is None: output = kwargs.get("x", None) @@ -102,18 +105,22 @@ def initialize_hook(self, module): return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): + print("Skipping transformer block") return self._metadata.skip_block_output_fn(module, *args, **kwargs) def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: r""" Apply layer skipping to internal layers of a transformer. + Args: module (`torch.nn.Module`): The transformer model to which the layer skip hook should be applied. config (`LayerSkipConfig`): The configuration for the layer skip hook. 
+ Example: + ```python >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) @@ -168,17 +175,13 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam registry = HookRegistry.check_if_exists_or_initialize(submodule) hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores) registry.register_hook(hook, name) - elif config.skip_ff: + if config.skip_ff: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _FEEDFORWARD_CLASSES): logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") registry = HookRegistry.check_if_exists_or_initialize(submodule) hook = FeedForwardSkipHook() registry.register_hook(hook, name) - else: - raise ValueError( - "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True." - ) if not blocks_found: raise ValueError( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 68d9d913bd3a..aed212c3f880 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2267,6 +2267,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) in enumerate(zip( latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): + pipeline.guider.prepare_models(pipeline.unet) latents_i = pipeline.scheduler.scale_model_input(latents_i, t) # Prepare for inpainting @@ -2670,6 +2671,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) in enumerate(zip( latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): + pipeline.guider.prepare_models(pipeline.unet) latents_i = pipeline.scheduler.scale_model_input(latents_i, t) # Prepare for inpainting @@ -3085,6 +3087,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) in enumerate(zip( latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): + pipeline.guider.prepare_models(pipeline.unet) latents_i = pipeline.scheduler.scale_model_input(latents_i, t) # Prepare for inpainting From 1c1d1d52e0712028253d5196545a79cc41ef005d Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 08:26:43 +0200 Subject: [PATCH 12/24] remove debug print --- src/diffusers/hooks/layer_skip.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index e28322ac4865..2a906c315c16 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -80,17 +80,14 @@ def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bo def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.skip_attention_scores: - print("Skipping attention scores") with AttentionScoreSkipFunctionMode(): return self.fn_ref.original_forward(*args, **kwargs) else: - print("Skipping attention processor output") return self.skip_processor_output_fn(module, *args, **kwargs) class FeedForwardSkipHook(ModelHook): def new_forward(self, module: torch.nn.Module, *args, **kwargs): - print("Skipping feed-forward block") output = 
kwargs.get("hidden_states", None) if output is None: output = kwargs.get("x", None) @@ -105,7 +102,6 @@ def initialize_hook(self, module): return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): - print("Skipping transformer block") return self._metadata.skip_block_output_fn(module, *args, **kwargs) From ba579f4da90b5d6836d80f8b9e61a6df35cc3bad Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 09:21:25 +0200 Subject: [PATCH 13/24] autoguidance --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 1 + src/diffusers/guiders/auto_guidance.py | 172 +++++++++++++++++++++++++ src/diffusers/hooks/layer_skip.py | 72 ++++++++--- 4 files changed, 232 insertions(+), 15 deletions(-) create mode 100644 src/diffusers/guiders/auto_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b2e24614b97a..0672e1035628 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -133,6 +133,7 @@ _import_structure["guiders"].extend( [ "AdaptiveProjectedGuidance", + "AutoGuidance", "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", @@ -724,6 +725,7 @@ else: from .guiders import ( AdaptiveProjectedGuidance, + AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 52b9176f5fb7..c23e48578ebc 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -19,6 +19,7 @@ if is_torch_available(): from .adaptive_projected_guidance import AdaptiveProjectedGuidance + from .auto_guidance import AutoGuidance from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py new file mode 100644 index 000000000000..8c759f497307 --- /dev/null +++ b/src/diffusers/guiders/auto_guidance.py @@ -0,0 +1,172 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class AutoGuidance(BaseGuidance): + """ + AutoGuidance: https://huggingface.co/papers/2406.02507 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + auto_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. 
If not + provided, `skip_layer_config` must be provided. + auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. + dropout (`float`, *optional*): + The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or + `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + auto_guidance_layers: Optional[Union[int, List[int]]] = None, + auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + dropout: Optional[float] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.auto_guidance_layers = auto_guidance_layers + self.auto_guidance_config = auto_guidance_config + self.dropout = dropout + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if auto_guidance_layers is None and auto_guidance_config is None: + raise ValueError( + "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance." + ) + if auto_guidance_layers is not None and auto_guidance_config is not None: + raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.") + if (dropout is None and auto_guidance_layers is not None) or (dropout is not None and auto_guidance_layers is None): + raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.") + + if auto_guidance_layers is not None: + if isinstance(auto_guidance_layers, int): + auto_guidance_layers = [auto_guidance_layers] + if not isinstance(auto_guidance_layers, list): + raise ValueError( + f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}." + ) + auto_guidance_config = [LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers] + + if isinstance(auto_guidance_config, LayerSkipConfig): + auto_guidance_config = [auto_guidance_config] + + if not isinstance(auto_guidance_config, list): + raise ValueError( + f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}." 
+ ) + + self.auto_guidance_config = auto_guidance_config + self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_ag_enabled() and self.is_unconditional: + for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + if key == "pred_uncond": + # If we are in AutoGuidance unconditional inference mode, we need to remove the hooks after inference + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._auto_guidance_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_ag_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_ag_enabled(): + num_conditions += 1 + return num_conditions + + def _is_ag_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 2a906c315c16..14b1cf492d0e 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from dataclasses import dataclass from typing import Callable, List, Optional @@ -47,8 +48,12 @@ class LayerSkipConfig: skip_ff (`bool`, defaults to `True`): Whether to skip feed-forward blocks. skip_attention_scores (`bool`, defaults to `False`): - Whether to skip attention score computation in the attention blocks. This is equivalent to using - `value` projections as the output of scaled dot product attention. + Whether to skip attention score computation in the attention blocks. This is equivalent to using `value` + projections as the output of scaled dot product attention. 
+ dropout (`float`, defaults to `1.0`): + The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`, + meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the + skipped layers are fully retained, which is equivalent to not skipping any layers. """ indices: List[int] @@ -56,6 +61,11 @@ class LayerSkipConfig: skip_attention: bool = True skip_attention_scores: bool = False skip_ff: bool = True + dropout: float = 1.0 + + def __post_init__(self): + if not (0 <= self.dropout <= 1): + raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.") class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): @@ -74,36 +84,62 @@ def __torch_function__(self, func, types, args=(), kwargs=None): class AttentionProcessorSkipHook(ModelHook): - def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False): + def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0): self.skip_processor_output_fn = skip_processor_output_fn self.skip_attention_scores = skip_attention_scores + self.dropout = dropout def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.skip_attention_scores: + if math.isclose(self.dropout, 1.0): + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) with AttentionScoreSkipFunctionMode(): - return self.fn_ref.original_forward(*args, **kwargs) + output = self.fn_ref.original_forward(*args, **kwargs) else: - return self.skip_processor_output_fn(module, *args, **kwargs) + if math.isclose(self.dropout, 1.0): + output = self.skip_processor_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output class FeedForwardSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + def new_forward(self, module: torch.nn.Module, *args, **kwargs): - output = kwargs.get("hidden_states", None) - if output is None: - output = kwargs.get("x", None) - if output is None and len(args) > 0: - output = args[0] + if math.isclose(self.dropout, 1.0): + output = kwargs.get("hidden_states", None) + if output is None: + output = kwargs.get("x", None) + if output is None and len(args) > 0: + output = args[0] + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) return output class TransformerBlockSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + def initialize_hook(self, module): self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): - return self._metadata.skip_block_output_fn(module, *args, **kwargs) - + if math.isclose(self.dropout, 1.0): + output = self._metadata.skip_block_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: r""" @@ -132,6 +168,8 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam if config.skip_attention and config.skip_attention_scores: raise ValueError("Cannot set 
both `skip_attention` and `skip_attention_scores` to True. Please choose one.") + if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores: + raise ValueError("Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0.") if config.fqn == "auto": for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: @@ -157,26 +195,30 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam for i, block in enumerate(transformer_blocks): if i not in config.indices: continue + blocks_found = True + if config.skip_attention and config.skip_ff: logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") registry = HookRegistry.check_if_exists_or_initialize(block) - hook = TransformerBlockSkipHook() + hook = TransformerBlockSkipHook(config.dropout) registry.register_hook(hook, name) + elif config.skip_attention or config.skip_attention_scores: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'") output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn registry = HookRegistry.check_if_exists_or_initialize(submodule) - hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores) + hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout) registry.register_hook(hook, name) + if config.skip_ff: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _FEEDFORWARD_CLASSES): logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") registry = HookRegistry.check_if_exists_or_initialize(submodule) - hook = FeedForwardSkipHook() + hook = FeedForwardSkipHook(config.dropout) registry.register_hook(hook, name) if not blocks_found: From 720783e508c9f76ed4ecd05763db96a989a0ec86 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 21:13:13 +0200 Subject: [PATCH 14/24] smoothed energy guidance --- src/diffusers/__init__.py | 4 + src/diffusers/guiders/__init__.py | 1 + src/diffusers/guiders/guider_utils.py | 4 +- src/diffusers/guiders/skip_layer_guidance.py | 5 +- .../guiders/smoothed_energy_guidance.py | 250 ++++++++++++++++++ src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/_common.py | 11 + src/diffusers/hooks/layer_skip.py | 16 +- .../hooks/smoothed_energy_guidance_utils.py | 148 +++++++++++ .../pipeline_stable_diffusion_xl_modular.py | 12 +- 10 files changed, 431 insertions(+), 21 deletions(-) create mode 100644 src/diffusers/guiders/smoothed_energy_guidance.py create mode 100644 src/diffusers/hooks/smoothed_energy_guidance_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0672e1035628..67d8ae9f79ac 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -137,6 +137,7 @@ "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", + "SmoothedEnergyGuidance", ] ) _import_structure["hooks"].extend( @@ -145,6 +146,7 @@ "HookRegistry", "PyramidAttentionBroadcastConfig", "LayerSkipConfig", + "SmoothedEnergyGuidanceConfig", "apply_faster_cache", "apply_layer_skip", "apply_pyramid_attention_broadcast", @@ -729,12 +731,14 @@ ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, + SmoothedEnergyGuidance, ) from .hooks import ( FasterCacheConfig, HookRegistry, LayerSkipConfig, PyramidAttentionBroadcastConfig, + 
SmoothedEnergyGuidanceConfig, apply_layer_skip, apply_faster_cache, apply_pyramid_attention_broadcast, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index c23e48578ebc..7b88d61c67b3 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -23,5 +23,6 @@ from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance + from .smoothed_energy_guidance import SmoothedEnergyGuidance GuiderType = Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance] diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 60859bf390f6..e7d22a50d0da 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -55,10 +55,10 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): "`_input_predictions` must be a list of required prediction names for the guidance technique." ) - def force_disable(self): + def _force_disable(self): self._enabled = False - def force_enable(self): + def _force_enable(self): self._enabled = True def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 3fbfd771eff9..64b2b8a73c1a 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -24,8 +24,9 @@ class SkipLayerGuidance(BaseGuidance): """ - Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG): - https://huggingface.co/papers/2411.18664 + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 + + Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664 SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py new file mode 100644 index 000000000000..bd2a61b894e7 --- /dev/null +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -0,0 +1,250 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import List, Optional, Tuple, Union + +import torch + +from ..hooks import HookRegistry +from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +class SmoothedEnergyGuidance(BaseGuidance): + """ + Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + seg_guidance_scale (`float`, defaults to `3.0`): + The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + seg_blur_sigma (`float`, defaults to `9999999.0`): + The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in + infinite blur, which means uniform queries. Controlling it exponentially is empirically effective. + seg_blur_threshold_inf (`float`, defaults to `9999.0`): + The threshold above which the blur is considered infinite. + seg_guidance_start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance starts. + seg_guidance_stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance stops. + seg_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If not + provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*): + The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list of + `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. 
+ """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + + def __init__( + self, + guidance_scale: float = 7.5, + seg_guidance_scale: float = 2.8, + seg_blur_sigma: float = 9999999.0, + seg_blur_threshold_inf: float = 9999.0, + seg_guidance_start: float = 0.0, + seg_guidance_stop: float = 1.0, + seg_guidance_layers: Optional[Union[int, List[int]]] = None, + seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.seg_guidance_scale = seg_guidance_scale + self.seg_blur_sigma = seg_blur_sigma + self.seg_blur_threshold_inf = seg_blur_threshold_inf + self.seg_guidance_start = seg_guidance_start + self.seg_guidance_stop = seg_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= seg_guidance_start < 1.0): + raise ValueError( + f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}." + ) + if not (seg_guidance_start <= seg_guidance_stop <= 1.0): + raise ValueError( + f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}." + ) + + if seg_guidance_layers is None and seg_guidance_config is None: + raise ValueError( + "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance." + ) + if seg_guidance_layers is not None and seg_guidance_config is not None: + raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.") + + if seg_guidance_layers is not None: + if isinstance(seg_guidance_layers, int): + seg_guidance_layers = [seg_guidance_layers] + if not isinstance(seg_guidance_layers, list): + raise ValueError( + f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}." + ) + seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers] + + if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig): + seg_guidance_config = [seg_guidance_config] + + if not isinstance(seg_guidance_config, list): + raise ValueError( + f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}." + ) + + self.seg_guidance_config = seg_guidance_config + self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_seg_enabled() and self.is_conditional and self._num_outputs_prepared > 0: + for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): + _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if arg is None or isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. 
Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + list_of_inputs.append([arg[0], arg[1], arg[0]]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + if not self._is_cfg_enabled() and self._is_seg_enabled(): + # If we're predicting pred_cond and pred_cond_seg only, we need to set the key to pred_cond_seg + # to avoid writing into pred_uncond which is not used + if self._num_outputs_prepared == 2: + key = "pred_cond_seg" + self._preds[key] = pred + + if key == "pred_cond_seg": + # If we are in SLG mode, we need to remove the hooks after inference + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._seg_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_seg: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_seg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_cond_seg + pred = pred + self.seg_guidance_scale * shift + elif not self._is_seg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_seg = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_seg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_seg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.seg_guidance_start * 
self._num_inference_steps) + skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.seg_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 142ff860371c..9d0e96e9e79e 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -8,3 +8,4 @@ from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py index 6ea83dcbf6a7..3d9c99e8189f 100644 --- a/src/diffusers/hooks/_common.py +++ b/src/diffusers/hooks/_common.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +import torch + from ..models.attention import FeedForward, LuminaFeedForward from ..models.attention_processor import Attention, MochiAttention @@ -30,3 +34,10 @@ *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, } ) + + +def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: + for submodule_name, submodule in module.named_modules(): + if submodule_name == fqn: + return submodule + return None diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 14b1cf492d0e..45f42a1f0f86 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -20,7 +20,7 @@ from ..utils import get_logger from ..utils.torch_utils import unwrap_module -from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry from .hooks import HookRegistry, ModelHook @@ -66,12 +66,13 @@ class LayerSkipConfig: def __post_init__(self): if not (0 <= self.dropout <= 1): raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.") + if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores: + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): - def __init__(self) -> None: - super().__init__() - def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} @@ -226,10 +227,3 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam f"Could not find any transformer blocks matching the provided indices {config.indices} and " f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." 
) - - -def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: - for submodule_name, submodule in module.named_modules(): - if submodule_name == fqn: - return submodule - return None diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py new file mode 100644 index 000000000000..20df0de048c7 --- /dev/null +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from ..utils import get_logger +from ._common import _ATTENTION_CLASSES, _get_submodule_from_fqn +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook" + + +@dataclass +class SmoothedEnergyGuidanceConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. + "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + _query_proj_identifiers (`List[str]`, defaults to `None`): + The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. + If `None`, `to_q` is used by default. 
+ """ + + indices: List[int] + fqn: str = "auto" + _query_proj_identifiers: List[str] = None + + +class SmoothedEnergyGuidanceHook(ModelHook): + def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None: + super().__init__() + self.blur_sigma = blur_sigma + self.blur_threshold_inf = blur_threshold_inf + + def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor: + # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102 + kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2 + smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf) + return smoothed_output + + +def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None: + name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + if config._query_proj_identifiers is None: + config._query_proj_identifiers = ["to_q"] + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + for submodule_name, submodule in block.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention: + continue + for identifier in config._query_proj_identifiers: + query_proj = getattr(submodule, identifier, None) + if query_proj is None or not isinstance(query_proj, torch.nn.Linear): + continue + logger.debug( + f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}" + ) + registry = HookRegistry.check_if_exists_or_initialize(query_proj) + hook = SmoothedEnergyGuidanceHook(blur_sigma) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." 
+ ) + + +# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 +def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: + assert query.ndim == 3 + + is_inf = sigma > sigma_threshold_inf + batch_size, seq_len, embed_dim = query.shape + + seq_len_sqrt = int(math.sqrt(seq_len)) + num_square_tokens = seq_len_sqrt * seq_len_sqrt + query_slice = query[:, :num_square_tokens, :] + query_slice = query_slice.permute(0, 2, 1) + query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt) + + if is_inf: + kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1)) + kernel_size_half = (kernel_size - 1) / 2 + + x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + kernel1d = kernel1d.to(query) + kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :]) + kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + query_slice = F.pad(query_slice, padding, mode="reflect") + query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim) + else: + query[:] = query.mean(dim=(-2, -1), keepdim=True) + + query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) + query_slice = query_slice.permute(0, 2, 1) + query[:, :num_square_tokens, :] = query_slice + + return query diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index aed212c3f880..76030a71535a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2231,9 +2231,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider.force_disable() + pipeline.guider._force_disable() else: - pipeline.guider.force_enable() + pipeline.guider._force_enable() # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -2626,9 +2626,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (2) Prepare conditional inputs for unet using the guider data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider.force_disable() + pipeline.guider._force_disable() else: - pipeline.guider.force_enable() + pipeline.guider._force_enable() # (4) Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -3039,9 +3039,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider.force_disable() + pipeline.guider._force_disable() else: - pipeline.guider.force_enable() + pipeline.guider._force_enable() data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] From b9bcd469f13287afb9deaf1486714170d26b9999 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 21:25:15 +0200 Subject: [PATCH 15/24] add note about seg --- src/diffusers/guiders/smoothed_energy_guidance.py | 3 +++ .../hooks/smoothed_energy_guidance_utils.py | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index bd2a61b894e7..2328aa82ec98 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -25,6 +25,9 @@ class SmoothedEnergyGuidance(BaseGuidance): """ Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. Args: guidance_scale (`float`, defaults to `7.5`): diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py index 20df0de048c7..f0366e29887f 100644 --- a/src/diffusers/hooks/smoothed_energy_guidance_utils.py +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -113,6 +113,16 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth # Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: + """ + This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian + blur. However, some models use joint text-visual token attention for which this may not be suitable. Additionally, + this implementation also assumes that the visual tokens come from a square image/video. In practice, despite + these assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results + for Smoothed Energy Guidance. + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. 
+ """ assert query.ndim == 3 is_inf = sigma > sigma_threshold_inf @@ -139,10 +149,10 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma query_slice = F.pad(query_slice, padding, mode="reflect") query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim) else: - query[:] = query.mean(dim=(-2, -1), keepdim=True) + query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True) query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) query_slice = query_slice.permute(0, 2, 1) - query[:, :num_square_tokens, :] = query_slice + query[:, :num_square_tokens, :] = query_slice.clone() return query From 2dc673a213dba107aa7463a73295421ff1d30218 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 06:38:03 +0200 Subject: [PATCH 16/24] tangential cfg --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 3 +- .../guiders/adaptive_projected_guidance.py | 7 +- .../guiders/smoothed_energy_guidance.py | 5 +- .../tangential_classifier_free_guidance.py | 133 ++++++++++++++++++ src/diffusers/hooks/layer_skip.py | 2 +- 6 files changed, 148 insertions(+), 4 deletions(-) create mode 100644 src/diffusers/guiders/tangential_classifier_free_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 67d8ae9f79ac..a4f55acf8b70 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -138,6 +138,7 @@ "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", "SmoothedEnergyGuidance", + "TangentialClassifierFreeGuidance", ] ) _import_structure["hooks"].extend( @@ -732,6 +733,7 @@ ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, + TangentialClassifierFreeGuidance, ) from .hooks import ( FasterCacheConfig, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 7b88d61c67b3..3c1ee293382d 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -24,5 +24,6 @@ from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance from .smoothed_energy_guidance import SmoothedEnergyGuidance + from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance - GuiderType = Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance] + GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance] diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 45bd196860f4..05c186e58d9f 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -155,20 +155,25 @@ def normalized_guidance( ): diff = pred_cond - pred_uncond dim = [-i for i in range(1, len(diff.shape))] + if momentum_buffer is not None: momentum_buffer.update(diff) diff = momentum_buffer.running_average + if norm_threshold > 0: ones = torch.ones_like(diff) diff_norm = diff.norm(p=2, dim=dim, keepdim=True) scale_factor = torch.minimum(ones, norm_threshold / diff_norm) diff = diff * scale_factor + v0, v1 = diff.double(), pred_cond.double() v1 = torch.nn.functional.normalize(v1, dim=dim) v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 v0_orthogonal = v0 - v0_parallel diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) normalized_update = 
diff_orthogonal + eta * diff_parallel + pred = pred_cond if use_original_formulation else pred_uncond - pred = pred + (guidance_scale - 1) * normalized_update + pred = pred + guidance_scale * normalized_update + return pred diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 2328aa82ec98..906900856f4c 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -27,7 +27,10 @@ class SmoothedEnergyGuidance(BaseGuidance): Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 SEG is only supported as an experimental prototype feature for now, so the implementation may be modified - in the future without warning or guarantee of reproducibility. + in the future without warning or guarantee of reproducibility. This implementation assumes: + - Generated images are square (height == width) + - The model does not combine different modalities together (e.g., text and image latent streams are + not combined together such as Flux) Args: guidance_scale (`float`, defaults to `7.5`): diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py new file mode 100644 index 000000000000..078d795baa68 --- /dev/null +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -0,0 +1,133 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union, Tuple, List + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class TangentialClassifierFreeGuidance(BaseGuidance): + """ + Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. 
+ """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_tcfg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_tcfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_tcfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor: + cond_dtype = pred_cond.dtype + preds = torch.stack([pred_cond, pred_uncond], dim=1).float() + preds = preds.flatten(2) + U, S, Vh = torch.linalg.svd(preds, full_matrices=False) + Vh_modified = Vh.clone() + Vh_modified[:, 1] = 0 + + uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float() + x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1)) + x_Vh_V = torch.matmul(x_Vh, Vh_modified) + pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype) + + pred = pred_cond if use_original_formulation else pred_uncond + shift = pred_cond - pred_uncond + pred = pred + guidance_scale * shift + + return pred diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 45f42a1f0f86..c50d2b7471e4 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -92,7 +92,7 @@ def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bo def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.skip_attention_scores: - if math.isclose(self.dropout, 1.0): + if not math.isclose(self.dropout, 1.0): raise ValueError( "Cannot set `skip_attention_scores` to 
True when `dropout` is not 1.0. Please set `dropout` to 1.0." ) From 77d8a285bf36960a2e0315725e322e2f0f1f6197 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 08:08:00 +0200 Subject: [PATCH 17/24] cfg plus plus --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 1 + .../classifier_free_guidance_plus_plus.py | 117 ++++++++++++++++++ src/diffusers/guiders/guider_utils.py | 9 +- .../tangential_classifier_free_guidance.py | 1 - .../pipeline_stable_diffusion_xl_modular.py | 9 +- .../schedulers/scheduling_euler_discrete.py | 29 +++++ 7 files changed, 163 insertions(+), 5 deletions(-) create mode 100644 src/diffusers/guiders/classifier_free_guidance_plus_plus.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a4f55acf8b70..424011961ab0 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -134,6 +134,7 @@ [ "AdaptiveProjectedGuidance", "AutoGuidance", + "CFGPlusPlusGuidance", "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", @@ -729,6 +730,7 @@ from .guiders import ( AdaptiveProjectedGuidance, AutoGuidance, + CFGPlusPlusGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 3c1ee293382d..56e95c92b697 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -20,6 +20,7 @@ if is_torch_available(): from .adaptive_projected_guidance import AdaptiveProjectedGuidance from .auto_guidance import AutoGuidance + from .classifier_free_guidance_plus_plus import CFGPlusPlusGuidance from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py new file mode 100644 index 000000000000..516dbfa0e05f --- /dev/null +++ b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py @@ -0,0 +1,117 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union, Tuple, List + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class CFGPlusPlusGuidance(BaseGuidance): + """ + CFG++: https://huggingface.co/papers/2406.08070 + + Args: + guidance_scale (`float`, defaults to `0.7`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. 
Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 0.7, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_cfgpp_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor: + if self._is_cfgpp_enabled(): + # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later! 
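+            # The scheduler's Euler step renoised the sample with the CFG++-combined prediction.
+            # CFG++ instead renoises with the unconditional prediction, which (for the
+            # diffusers-native formulation) is equivalent to adding
+            # sigma_next * guidance_scale * (pred_uncond - pred_cond) to the stepped sample.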
+ pred_cond = self._preds["pred_cond"] + pred_uncond = self._preds["pred_uncond"] + diff = pred_uncond - pred_cond + pred = pred + diff * self.guidance_scale * self._sigma_next + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfgpp_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfgpp_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + return is_within_range diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index e7d22a50d0da..f51452ed0cee 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -37,6 +37,8 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): self._step: int = None self._num_inference_steps: int = None self._timestep: torch.LongTensor = None + self._sigma: torch.Tensor = None + self._sigma_next: torch.Tensor = None self._preds: Dict[str, torch.Tensor] = {} self._num_outputs_prepared: int = 0 self._enabled = True @@ -61,10 +63,12 @@ def _force_disable(self): def _force_enable(self): self._enabled = True - def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor, sigma: torch.Tensor, sigma_next: torch.Tensor) -> None: self._step = step self._num_inference_steps = num_inference_steps self._timestep = timestep + self._sigma = sigma + self._sigma_next = sigma_next self._preds = {} self._num_outputs_prepared = 0 @@ -91,6 +95,9 @@ def __call__(self, **kwargs) -> Any: def forward(self, *args, **kwargs) -> Any: raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") + def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor: + return pred + @property def is_conditional(self) -> bool: raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 078d795baa68..7529114bfd6f 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -58,7 +58,6 @@ def __init__( self.guidance_scale = guidance_scale self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - self.momentum_buffer = None def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: return _default_prepare_inputs(denoiser, self.num_conditions, *args) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 76030a71535a..8e0ea4545f29 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2241,7 +2241,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with 
pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) ( latents, @@ -2301,6 +2301,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + data.latents = pipeline.guider.post_scheduler_step(data.latents) if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): @@ -2637,7 +2638,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (5) Denoise loop with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) ( latents, @@ -2730,6 +2731,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + data.latents = pipeline.guider.post_scheduler_step(data.latents) if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): @@ -3053,7 +3055,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) ( latents, @@ -3148,6 +3150,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + data.latents = pipeline.guider.post_scheduler_step(data.latents) if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 56757f3ca197..4adec768b776 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -669,6 +669,35 @@ def step( prev_sample = sample + derivative * dt + # denoised = sample - model_output * sigmas[i] + # d = (sample - denoised) / sigmas[i] + # new_sample = denoised + d * sigmas[i + 1] + + # new_sample = denoised + (sample - denoised) * sigmas[i + 1] / sigmas[i] + # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] + # new_sample = sample + model_output * (sigmas[i + 1] - sigmas[i]) + # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] --- (1) + + 
# CFG++ ===== + # denoised = sample - model_output * sigmas[i] + # uncond_denoised = sample - model_output_uncond * sigmas[i] + # d = (sample - uncond_denoised) / sigmas[i] + # new_sample = denoised + d * sigmas[i + 1] + + # new_sample = denoised + (sample - uncond_denoised) * sigmas[i + 1] / sigmas[i] + # new_sample = sample - model_output * sigmas[i] + model_output_uncond * sigmas[i + 1] --- (2) + + # To go from (1) to (2): + # new_sample_2 = new_sample_1 - model_output * sigmas[i + 1] + model_output_uncond * sigmas[i + 1] + # new_sample_2 = new_sample_1 + (model_output_uncond - model_output) * sigmas[i + 1] + # new_sample_2 = new_sample_1 + diff * sigmas[i + 1] + + # diff = model_output_uncond - model_output + # diff = model_output_uncond - (model_output_uncond + g * (model_output_cond - model_output_uncond)) + # diff = model_output_uncond - (g * model_output_cond + (1 - g) * model_output_uncond) + # diff = model_output_uncond - g * model_output_cond + (g - 1) * model_output_uncond + # diff = g * (model_output_uncond - model_output_cond) + # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) From 78fca12803d69541fc63161b036db96792520fd8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 13:23:24 +0200 Subject: [PATCH 18/24] support cfgpp in ddim --- .../classifier_free_guidance_plus_plus.py | 20 ++++++++---------- src/diffusers/guiders/guider_utils.py | 11 ++-------- .../pipeline_stable_diffusion_xl_modular.py | 21 ++++++++----------- src/diffusers/schedulers/scheduling_ddim.py | 9 +++++++- .../schedulers/scheduling_euler_discrete.py | 11 +++++++++- 5 files changed, 38 insertions(+), 34 deletions(-) diff --git a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py index 516dbfa0e05f..d1c6f8744143 100644 --- a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py +++ b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Optional, Union, Tuple, List +from typing import Dict, Optional, Union, Tuple, List import torch @@ -84,15 +83,6 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = return pred - def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor: - if self._is_cfgpp_enabled(): - # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later! 
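As a quick sanity check on the derivation in the comments above, the following editor's sketch (not part of the patch; shapes, sigmas, and the guidance scale are arbitrary) verifies numerically that adding `(model_output_uncond - model_output) * sigma_next` to the plain Euler step reproduces the CFG++ update "denoise with the combined prediction, re-noise with the unconditional one":

```python
import torch

torch.manual_seed(0)
sample = torch.randn(2, 4, 8, 8)
eps_cond = torch.randn_like(sample)
eps_uncond = torch.randn_like(sample)
g = 0.8                        # CFG++ guidance scale (typically < 1)
sigma, sigma_next = 14.6, 10.8

# Combined prediction, as produced by the guider's `forward`.
eps = eps_uncond + g * (eps_cond - eps_uncond)

# CFG++ target from the derivation: denoise with `eps`, re-noise with the unconditional prediction.
denoised = sample - eps * sigma
target = denoised + eps_uncond * sigma_next

# Plain Euler step followed by the correction term introduced in this patch.
prev_sample = sample + eps * (sigma_next - sigma)
prev_sample = prev_sample + (eps_uncond - eps) * sigma_next

assert torch.allclose(prev_sample, target, atol=1e-4)
```

Since `eps_uncond - eps == g * (eps_uncond - eps_cond)`, this is the same correction that `post_scheduler_step` applied before this patch moved it into the scheduler itself.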
- pred_cond = self._preds["pred_cond"] - pred_uncond = self._preds["pred_uncond"] - diff = pred_uncond - pred_cond - pred = pred + diff * self.guidance_scale * self._sigma_next - return pred - @property def is_conditional(self) -> bool: return self._num_outputs_prepared == 0 @@ -104,6 +94,14 @@ def num_conditions(self) -> int: num_conditions += 1 return num_conditions + @property + def outputs(self) -> Dict[str, torch.Tensor]: + scheduler_step_kwargs = {} + if self._is_cfgpp_enabled(): + scheduler_step_kwargs["_use_cfgpp"] = True + scheduler_step_kwargs["_model_output_uncond"] = self._preds.get("pred_uncond") + return self._preds, scheduler_step_kwargs + def _is_cfgpp_enabled(self) -> bool: if not self._enabled: return False diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index f51452ed0cee..420a56690678 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -37,8 +37,6 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): self._step: int = None self._num_inference_steps: int = None self._timestep: torch.LongTensor = None - self._sigma: torch.Tensor = None - self._sigma_next: torch.Tensor = None self._preds: Dict[str, torch.Tensor] = {} self._num_outputs_prepared: int = 0 self._enabled = True @@ -63,12 +61,10 @@ def _force_disable(self): def _force_enable(self): self._enabled = True - def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor, sigma: torch.Tensor, sigma_next: torch.Tensor) -> None: + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: self._step = step self._num_inference_steps = num_inference_steps self._timestep = timestep - self._sigma = sigma - self._sigma_next = sigma_next self._preds = {} self._num_outputs_prepared = 0 @@ -95,9 +91,6 @@ def __call__(self, **kwargs) -> Any: def forward(self, *args, **kwargs) -> Any: raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") - def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor: - return pred - @property def is_conditional(self) -> bool: raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") @@ -112,7 +105,7 @@ def num_conditions(self) -> int: @property def outputs(self) -> Dict[str, torch.Tensor]: - return self._preds + return self._preds, {} def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 8e0ea4545f29..37d2bbbe6ca6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2241,7 +2241,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) ( latents, @@ -2295,13 +2295,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) # Perform guidance - 
outputs = pipeline.guider.outputs + outputs, scheduler_step_kwargs = pipeline.guider.outputs data.noise_pred = pipeline.guider(**outputs) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - data.latents = pipeline.guider.post_scheduler_step(data.latents) + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): @@ -2638,7 +2637,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (5) Denoise loop with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) ( latents, @@ -2725,13 +2724,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) # Perform guidance - outputs = pipeline.guider.outputs + outputs, scheduler_step_kwargs = pipeline.guider.outputs data.noise_pred = pipeline.guider(**outputs) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - data.latents = pipeline.guider.post_scheduler_step(data.latents) + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): @@ -3055,7 +3053,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) ( latents, @@ -3144,13 +3142,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) # Perform guidance - outputs = pipeline.guider.outputs + outputs, scheduler_step_kwargs = pipeline.guider.outputs data.noise_pred = pipeline.guider(**outputs) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - data.latents = pipeline.guider.post_scheduler_step(data.latents) + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 13c9b3b4a5e9..2e74c9bbfccd 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ 
b/src/diffusers/schedulers/scheduling_ddim.py @@ -349,6 +349,8 @@ def step( generator=None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, + _model_output_uncond: Optional[torch.Tensor] = None, + _use_cfgpp: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -386,6 +388,11 @@ def step( raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) + + if _use_cfgpp and self.config.prediction_type != "epsilon": + raise ValueError( + f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}." + ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding @@ -411,7 +418,7 @@ def step( # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - pred_epsilon = model_output + pred_epsilon = model_output if not _use_cfgpp else _model_output_uncond elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 4adec768b776..4c82ca7e389b 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -584,6 +584,8 @@ def step( s_noise: float = 1.0, generator: Optional[torch.Generator] = None, return_dict: bool = True, + _model_output_uncond: Optional[torch.Tensor] = None, + _use_cfgpp: bool = False, ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -627,6 +629,11 @@ def step( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) + + if _use_cfgpp and self.config.prediction_type != "epsilon": + raise ValueError( + f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}." 
+ ) if self.step_index is None: self._init_step_index(timestep) @@ -668,7 +675,9 @@ def step( dt = self.sigmas[self.step_index + 1] - sigma_hat prev_sample = sample + derivative * dt - + if _use_cfgpp: + prev_sample = prev_sample + (_model_output_uncond - model_output) * self.sigmas[self.step_index + 1] + # denoised = sample - model_output * sigmas[i] # d = (sample - denoised) / sigmas[i] # new_sample = denoised + d * sigmas[i + 1] From e8768e58bd79d2936c3ef3098b0aec7b8c0b0491 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Apr 2025 21:39:56 +0200 Subject: [PATCH 19/24] apply review suggestions --- .../guiders/classifier_free_guidance.py | 23 ++-- .../guiders/entropy_rectifying_guidance.py | 0 src/diffusers/guiders/guider_utils.py | 104 +++++++++++++++--- .../pipeline_stable_diffusion_xl_modular.py | 68 +++++------- 4 files changed, 128 insertions(+), 67 deletions(-) create mode 100644 src/diffusers/guiders/entropy_rectifying_guidance.py diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 6978080b7152..1df38b897107 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -13,12 +13,15 @@ # limitations under the License. import math -from typing import Optional, Union, Tuple, List +from typing import Optional, Union, Tuple, List, TYPE_CHECKING import torch from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + class ClassifierFreeGuidance(BaseGuidance): """ @@ -72,15 +75,13 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - return _default_prepare_inputs(denoiser, self.num_conditions, *args) - - def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - self._num_outputs_prepared += 1 - if self._num_outputs_prepared > self.num_conditions: - raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") - key = self._input_predictions[self._num_outputs_prepared - 1] - self._preds[key] = pred + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None @@ -95,7 +96,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred + return pred, {} @property def is_conditional(self) -> bool: diff --git a/src/diffusers/guiders/entropy_rectifying_guidance.py b/src/diffusers/guiders/entropy_rectifying_guidance.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 420a56690678..4643346c0099 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from ..models.attention_processor import 
AttentionProcessor + from ..pipelines.modular_pipeline import BlockState logger = get_logger(__name__) # pylint: disable=invalid-name @@ -30,6 +31,7 @@ class BaseGuidance: r"""Base class providing the skeleton for implementing guidance techniques.""" _input_predictions = None + _identifier_key = "__guidance_identifier__" def __init__(self, start: float = 0.0, stop: float = 1.0): self._start = start @@ -37,7 +39,7 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): self._step: int = None self._num_inference_steps: int = None self._timestep: torch.LongTensor = None - self._preds: Dict[str, torch.Tensor] = {} + self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None self._num_outputs_prepared: int = 0 self._enabled = True @@ -65,9 +67,45 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen self._step = step self._num_inference_steps = num_inference_steps self._timestep = timestep - self._preds = {} self._num_outputs_prepared = 0 + def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None: + """ + Set the input fields for the guidance technique. The input fields are used to specify the names of the + returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is + obtained from the values of the provided keyword arguments to this method. + + Args: + **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once + it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, + which is used to look up the required data provided for preparation. + + If a string is provided, it will be used as the conditional data (or unconditional if used with + a guidance method that requires it). If a tuple of length 2 is provided, the first element must + be the conditional data identifier and the second element must be the unconditional data identifier + or None. + + Example: + + ``` + data = {"prompt_embeds": , "negative_prompt_embeds": , "latents": } + + BaseGuidance.set_input_fields( + latents="latents", + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + ) + ``` + """ + for key, value in kwargs.items(): + is_string = isinstance(value, str) + is_tuple_of_str_with_len_2 = isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value) + if not (is_string or is_tuple_of_str_with_len_2): + raise ValueError( + f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}." + ) + self._input_fields = kwargs + def prepare_models(self, denoiser: torch.nn.Module) -> None: """ Prepares the models for the guidance technique on a given batch of data. 
This method should be overridden in @@ -75,18 +113,18 @@ def prepare_models(self, denoiser: torch.nn.Module) -> None: """ pass - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") - def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - raise NotImplementedError("BaseGuidance::prepare_outputs must be implemented in subclasses.") - - def __call__(self, **kwargs) -> Any: - if len(kwargs) != self.num_conditions: + def __call__(self, data: List["BlockState"]) -> Any: + if not all(hasattr(d, "noise_pred") for d in data): + raise ValueError("Expected all data to have `noise_pred` attribute.") + if len(data) != self.num_conditions: raise ValueError( - f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments." + f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data." ) - return self.forward(**kwargs) + forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data} + return self.forward(**forward_inputs) def forward(self, *args, **kwargs) -> Any: raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") @@ -102,10 +140,48 @@ def is_unconditional(self) -> bool: @property def num_conditions(self) -> int: raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") - - @property - def outputs(self) -> Dict[str, torch.Tensor]: - return self._preds, {} + + @classmethod + def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState": + """ + Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of + the `BaseGuidance` class. It prepares the batch based on the provided tuple index. + + Args: + input_fields (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once + it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, + which is used to look up the required data provided for preparation. + If a string is provided, it will be used as the conditional data (or unconditional if used with + a guidance method that requires it). If a tuple of length 2 is provided, the first element must + be the conditional data identifier and the second element must be the unconditional data identifier + or None. + data (`BlockState`): + The input data to be prepared. + tuple_index (`int`): + The index to use when accessing input fields that are tuples. + + Returns: + `BlockState`: The prepared batch of data. + """ + from ..pipelines.modular_pipeline import BlockState + + if input_fields is None: + raise ValueError("Input fields have not been set. 
Please call `set_input_fields` before preparing inputs.") + data_batch = {} + for key, value in input_fields.items(): + try: + if isinstance(value, str): + data_batch[key] = getattr(data, value) + elif isinstance(value, tuple): + data_batch[key] = getattr(data, value[tuple_index]) + else: + # We've already checked that value is a string or a tuple of strings with length 2 + pass + except AttributeError: + raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.") + data_batch[cls._identifier_key] = identifier + return BlockState(**data_batch) def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 37d2bbbe6ca6..b70a4b6c082a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2239,64 +2239,48 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + pipeline.guider.set_input_fields( + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + add_time_ids=("add_time_ids", "negative_add_time_ids"), + pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), + ) + with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + guider_data = pipeline.guider.prepare_inputs(data) - ( - latents, - prompt_embeds, - add_time_ids, - pooled_prompt_embeds, - mask, - masked_image_latents, - ip_adapter_embeds, - ) = pipeline.guider.prepare_inputs( - pipeline.unet, - data.latents, - (data.prompt_embeds, data.negative_prompt_embeds), - (data.add_time_ids, data.negative_add_time_ids), - (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds), - data.mask, - data.masked_image_latents, - (data.ip_adapter_embeds, data.negative_ip_adapter_embeds), - ) - - for batch_index, ( - latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i, - ) in enumerate(zip( - latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds - )): + data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) + + # Prepare for inpainting + if data.num_channels_unet == 9: + data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) + + for batch in guider_data: pipeline.guider.prepare_models(pipeline.unet) - latents_i = pipeline.scheduler.scale_model_input(latents_i, t) - - # Prepare for inpainting - if data.num_channels_unet == 9: - latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) # Prepare additional conditionings - data.added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds_i, - "time_ids": add_time_ids_i, + batch.added_cond_kwargs = { + "text_embeds": batch.pooled_prompt_embeds, + "time_ids": batch.add_time_ids, } - if ip_adapter_embeds_i is not None: - data.added_cond_kwargs["image_embeds"] = 
ip_adapter_embeds_i - + if batch.ip_adapter_embeds is not None: + batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + # Predict the noise residual - data.noise_pred = pipeline.unet( - latents_i, + batch.noise_pred = pipeline.unet( + data.scaled_latents, t, - encoder_hidden_states=prompt_embeds_i, + encoder_hidden_states=batch.prompt_embeds, timestep_cond=data.timestep_cond, cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, + added_cond_kwargs=batch.added_cond_kwargs, return_dict=False, )[0] - data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) # Perform guidance - outputs, scheduler_step_kwargs = pipeline.guider.outputs - data.noise_pred = pipeline.guider(**outputs) + data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype From 0d5a788eeee5cb6f2d4fdf1d1811705214b44c64 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 25 Apr 2025 00:23:47 +0200 Subject: [PATCH 20/24] refactor --- .../guiders/adaptive_projected_guidance.py | 27 +-- src/diffusers/guiders/auto_guidance.py | 43 ++-- .../guiders/classifier_free_guidance.py | 6 +- .../classifier_free_guidance_plus_plus.py | 38 ++-- .../classifier_free_zero_star_guidance.py | 27 +-- src/diffusers/guiders/guider_utils.py | 12 +- src/diffusers/guiders/skip_layer_guidance.py | 67 +++--- .../guiders/smoothed_energy_guidance.py | 66 +++--- .../tangential_classifier_free_guidance.py | 27 +-- .../pipeline_stable_diffusion_xl_modular.py | 190 ++++++++---------- 10 files changed, 227 insertions(+), 276 deletions(-) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 05c186e58d9f..7da1cc59a365 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -13,11 +13,14 @@ # limitations under the License. 
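As a rough, self-contained illustration of the refactored data flow above (`set_input_fields` → `prepare_inputs` → one denoiser call per prepared batch → `guider(...)` returning the combined prediction plus scheduler kwargs), here is an editor's sketch; `ToyCFG`, `toy_denoiser`, and the tensor shapes are made-up stand-ins, not diffusers APIs:

```python
from types import SimpleNamespace

import torch


class ToyCFG:
    """Drastically simplified stand-in that mirrors the refactored guider protocol."""

    def __init__(self, guidance_scale: float = 7.5):
        self.guidance_scale = guidance_scale
        self._input_fields = None

    def set_input_fields(self, **fields):
        # e.g. prompt_embeds=("prompt_embeds", "negative_prompt_embeds")
        self._input_fields = fields

    def prepare_inputs(self, data):
        batches = []
        for idx, pred_name in zip((0, 1), ("pred_cond", "pred_uncond")):
            batch = {}
            for key, value in self._input_fields.items():
                source = value if isinstance(value, str) else value[idx]
                batch[key] = getattr(data, source)
            batch["__guidance_identifier__"] = pred_name
            batches.append(SimpleNamespace(**batch))
        return batches

    def __call__(self, batches):
        preds = {getattr(b, "__guidance_identifier__"): b.noise_pred for b in batches}
        pred = preds["pred_uncond"] + self.guidance_scale * (preds["pred_cond"] - preds["pred_uncond"])
        return pred, {}  # the second element plays the role of `scheduler_step_kwargs`


def toy_denoiser(latents, prompt_embeds):
    # Stand-in for the UNet: the embeddings just shift the output so the two passes differ.
    return latents + prompt_embeds.mean()


data = SimpleNamespace(
    latents=torch.randn(1, 4, 8, 8),
    prompt_embeds=torch.ones(1, 77, 32),
    negative_prompt_embeds=torch.zeros(1, 77, 32),
)

guider = ToyCFG(guidance_scale=5.0)
guider.set_input_fields(prompt_embeds=("prompt_embeds", "negative_prompt_embeds"))

guider_data = guider.prepare_inputs(data)
for batch in guider_data:
    batch.noise_pred = toy_denoiser(data.latents, batch.prompt_embeds)
noise_pred, scheduler_step_kwargs = guider(guider_data)
print(noise_pred.shape, scheduler_step_kwargs)
```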
import math -from typing import Optional, Union, Tuple, List +from typing import Optional, List, TYPE_CHECKING import torch -from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState class AdaptiveProjectedGuidance(BaseGuidance): @@ -70,18 +73,16 @@ def __init__( self.use_original_formulation = use_original_formulation self.momentum_buffer = None - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: if self._step == 0: if self.adaptive_projected_guidance_momentum is not None: self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) - return _default_prepare_inputs(denoiser, self.num_conditions, *args) - - def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - self._num_outputs_prepared += 1 - if self._num_outputs_prepared > self.num_conditions: - raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") - key = self._input_predictions[self._num_outputs_prepared - 1] - self._preds[key] = pred + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None @@ -102,11 +103,11 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred + return pred, {} @property def is_conditional(self) -> bool: - return self._num_outputs_prepared == 0 + return self._count_prepared == 1 @property def num_conditions(self) -> int: diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index 8c759f497307..bfffb9f39cd2 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -13,13 +13,16 @@ # limitations under the License. 
import math -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union, TYPE_CHECKING import torch from ..hooks import HookRegistry, LayerSkipConfig from ..hooks.layer_skip import _apply_layer_skip_hook -from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState class AutoGuidance(BaseGuidance): @@ -106,26 +109,24 @@ def __init__( self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))] def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 if self._is_ag_enabled() and self.is_unconditional: for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config): _apply_layer_skip_hook(denoiser, config, name=name) - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - return _default_prepare_inputs(denoiser, self.num_conditions, *args) - - def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - self._num_outputs_prepared += 1 - if self._num_outputs_prepared > self.num_conditions: - raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") - key = self._input_predictions[self._num_outputs_prepared - 1] - self._preds[key] = pred - - if key == "pred_uncond": - # If we are in AutoGuidance unconditional inference mode, we need to remove the hooks after inference - registry = HookRegistry.check_if_exists_or_initialize(denoiser) - # Remove the hooks after inference - for hook_name in self._auto_guidance_hook_names: - registry.remove_hook(hook_name, recurse=True) + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_ag_enabled() and self.is_unconditional: + for name in self._auto_guidance_hook_names: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + registry.remove_hook(name, recurse=True) + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None @@ -139,12 +140,12 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - - return pred + + return pred, {} @property def is_conditional(self) -> bool: - return self._num_outputs_prepared == 0 + return self._count_prepared == 1 @property def num_conditions(self) -> int: diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 1df38b897107..429f8450410a 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -13,11 +13,11 @@ # limitations under the License. 
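The `prepare_models` / `cleanup_models` split used by `AutoGuidance` above (attach layer-skip hooks only for the pass that needs them, and remove them right after the forward pass instead of inside the old `prepare_outputs`) can be pictured with a small mocked sketch; none of these class or attribute names are real diffusers APIs:

```python
class FakeDenoiser:
    """Pretend UNet whose behaviour changes while 'layer skip' hooks are attached."""

    def __init__(self):
        self.hooks = []

    def __call__(self, x):
        return x + len(self.hooks)


class FakeAutoGuidance:
    def __init__(self):
        self._count_prepared = 0

    @property
    def is_conditional(self):
        return self._count_prepared == 1

    def prepare_models(self, denoiser):
        self._count_prepared += 1
        if not self.is_conditional:
            # the unconditional pass runs with layers skipped
            denoiser.hooks.append("layer_skip")

    def cleanup_models(self, denoiser):
        # always restore the denoiser after the forward pass
        denoiser.hooks.clear()


denoiser, guider = FakeDenoiser(), FakeAutoGuidance()
for _ in range(2):
    guider.prepare_models(denoiser)
    out = denoiser(0.0)
    guider.cleanup_models(denoiser)
    print(out)  # 0.0 for the conditional pass, 1.0 for the hooked unconditional pass
```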
import math -from typing import Optional, Union, Tuple, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING import torch -from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs +from .guider_utils import BaseGuidance, rescale_noise_cfg if TYPE_CHECKING: from ..pipelines.modular_pipeline import BlockState @@ -100,7 +100,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = @property def is_conditional(self) -> bool: - return self._num_outputs_prepared == 0 + return self._count_prepared == 1 @property def num_conditions(self) -> int: diff --git a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py index d1c6f8744143..1f44f883c248 100644 --- a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py +++ b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Union, Tuple, List +from typing import Dict, Optional, List, TYPE_CHECKING import torch -from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState class CFGPlusPlusGuidance(BaseGuidance): @@ -58,15 +61,13 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - return _default_prepare_inputs(denoiser, self.num_conditions, *args) - - def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - self._num_outputs_prepared += 1 - if self._num_outputs_prepared > self.num_conditions: - raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") - key = self._input_predictions[self._num_outputs_prepared - 1] - self._preds[key] = pred + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None @@ -81,11 +82,14 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred + scheduler_kwargs = {} + if self._is_cfgpp_enabled(): + scheduler_kwargs = {"_use_cfgpp": True, "_model_output_uncond": pred_uncond} + return pred, scheduler_kwargs @property def is_conditional(self) -> bool: - return self._num_outputs_prepared == 0 + return self._count_prepared == 1 @property def num_conditions(self) -> int: @@ -94,14 +98,6 @@ def num_conditions(self) -> int: num_conditions += 1 return num_conditions - @property - def outputs(self) -> Dict[str, torch.Tensor]: - scheduler_step_kwargs = {} - if self._is_cfgpp_enabled(): - scheduler_step_kwargs["_use_cfgpp"] = True - scheduler_step_kwargs["_model_output_uncond"] = self._preds.get("pred_uncond") - return self._preds, scheduler_step_kwargs - def 
_is_cfgpp_enabled(self) -> bool: if not self._enabled: return False diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 04c504f8f2d6..4c9839ee78f3 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -13,11 +13,14 @@ # limitations under the License. import math -from typing import Optional, Union, Tuple, List +from typing import Optional, List, TYPE_CHECKING import torch -from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState class ClassifierFreeZeroStarGuidance(BaseGuidance): @@ -70,15 +73,13 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - return _default_prepare_inputs(denoiser, self.num_conditions, *args) - - def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - self._num_outputs_prepared += 1 - if self._num_outputs_prepared > self.num_conditions: - raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") - key = self._input_predictions[self._num_outputs_prepared - 1] - self._preds[key] = pred + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None @@ -100,11 +101,11 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred + return pred, {} @property def is_conditional(self) -> bool: - return self._num_outputs_prepared == 0 + return self._count_prepared == 1 @property def num_conditions(self) -> int: diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 4643346c0099..f144a77aa11c 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -39,8 +39,8 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): self._step: int = None self._num_inference_steps: int = None self._timestep: torch.LongTensor = None + self._count_prepared = 0 self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None - self._num_outputs_prepared: int = 0 self._enabled = True if not (0.0 <= start < 1.0): @@ -67,7 +67,7 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen self._step = step self._num_inference_steps = num_inference_steps self._timestep = timestep - self._num_outputs_prepared = 0 + self._count_prepared = 0 def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None: """ @@ -111,6 +111,14 @@ def prepare_models(self, denoiser: torch.nn.Module) -> None: Prepares the models for the guidance technique on a given batch of data. 
This method should be overridden in subclasses to implement specific model preparation logic. """ + self._count_prepared += 1 + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + """ + Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in + subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful + modifications made during `prepare_models`. + """ pass def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 64b2b8a73c1a..bdd9e4af81b6 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union, TYPE_CHECKING import torch @@ -21,6 +21,9 @@ from ..hooks.layer_skip import _apply_layer_skip_hook from .guider_utils import BaseGuidance, rescale_noise_cfg +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + class SkipLayerGuidance(BaseGuidance): """ @@ -141,51 +144,33 @@ def __init__( self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] def prepare_models(self, denoiser: torch.nn.Module) -> None: - if self._is_slg_enabled() and self.is_conditional and self._num_outputs_prepared > 0: + self._count_prepared += 1 + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): _apply_layer_skip_hook(denoiser, config, name=name) - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - num_conditions = self.num_conditions - list_of_inputs = [] - for arg in args: - if arg is None or isinstance(arg, torch.Tensor): - list_of_inputs.append([arg] * num_conditions) - elif isinstance(arg, (tuple, list)): - if len(arg) != 2: - raise ValueError( - f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " - f"with the first element being the conditional input and the second element being the unconditional input or None." - ) - if arg[1] is None: - # Only conditioning inputs for all batches - list_of_inputs.append([arg[0]] * num_conditions) - else: - list_of_inputs.append([arg[0], arg[1], arg[0]]) - else: - raise ValueError( - f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." 
- ) - return tuple(list_of_inputs) - - def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - self._num_outputs_prepared += 1 - if self._num_outputs_prepared > self.num_conditions: - raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") - key = self._input_predictions[self._num_outputs_prepared - 1] - if not self._is_cfg_enabled() and self._is_slg_enabled(): - # If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip - # to avoid writing into pred_uncond which is not used - if self._num_outputs_prepared == 2: - key = "pred_cond_skip" - self._preds[key] = pred - - if key == "pred_cond_skip": - # If we are in SLG mode, we need to remove the hooks after inference + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: registry = HookRegistry.check_if_exists_or_initialize(denoiser) # Remove the hooks after inference for hook_name in self._skip_layer_hook_names: registry.remove_hook(hook_name, recurse=True) + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches def forward( self, @@ -214,11 +199,11 @@ def forward( if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred + return pred, {} @property def is_conditional(self) -> bool: - return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2 + return self._count_prepared == 1 or self._count_prepared == 3 @property def num_conditions(self) -> int: diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 906900856f4c..1c7ee45dc3db 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. 
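To make the three-pass batching in `SkipLayerGuidance.prepare_inputs` above concrete: with both CFG and skip-layer guidance active, the tuple indices `[0, 1, 0]` mean the third (skip-layer) pass reuses the conditional inputs while the denoiser runs with the skip hooks attached. A minimal, illustrative-only sketch of that index mapping, using strings in place of tensors:

```python
input_fields = {"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds")}
data = {"prompt_embeds": "cond_embeds", "negative_prompt_embeds": "uncond_embeds"}

tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]

for index, name in zip(tuple_indices, input_predictions):
    batch = {key: data[value[index]] for key, value in input_fields.items()}
    print(name, "->", batch)
# pred_cond      -> {'prompt_embeds': 'cond_embeds'}
# pred_uncond    -> {'prompt_embeds': 'uncond_embeds'}
# pred_cond_skip -> {'prompt_embeds': 'cond_embeds'}   (conditional inputs, layers skipped)
```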
import math -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union, TYPE_CHECKING import torch @@ -21,6 +21,9 @@ from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook from .guider_utils import BaseGuidance, rescale_noise_cfg +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState + class SmoothedEnergyGuidance(BaseGuidance): """ @@ -135,51 +138,32 @@ def __init__( self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))] def prepare_models(self, denoiser: torch.nn.Module) -> None: - if self._is_seg_enabled() and self.is_conditional and self._num_outputs_prepared > 0: + if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - num_conditions = self.num_conditions - list_of_inputs = [] - for arg in args: - if arg is None or isinstance(arg, torch.Tensor): - list_of_inputs.append([arg] * num_conditions) - elif isinstance(arg, (tuple, list)): - if len(arg) != 2: - raise ValueError( - f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " - f"with the first element being the conditional input and the second element being the unconditional input or None." - ) - if arg[1] is None: - # Only conditioning inputs for all batches - list_of_inputs.append([arg[0]] * num_conditions) - else: - list_of_inputs.append([arg[0], arg[1], arg[0]]) - else: - raise ValueError( - f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." 
- ) - return tuple(list_of_inputs) - - def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - self._num_outputs_prepared += 1 - if self._num_outputs_prepared > self.num_conditions: - raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") - key = self._input_predictions[self._num_outputs_prepared - 1] - if not self._is_cfg_enabled() and self._is_seg_enabled(): - # If we're predicting pred_cond and pred_cond_seg only, we need to set the key to pred_cond_seg - # to avoid writing into pred_uncond which is not used - if self._num_outputs_prepared == 2: - key = "pred_cond_seg" - self._preds[key] = pred - - if key == "pred_cond_seg": - # If we are in SLG mode, we need to remove the hooks after inference + def cleanup_models(self, denoiser: torch.nn.Module): + if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: registry = HookRegistry.check_if_exists_or_initialize(denoiser) # Remove the hooks after inference for hook_name in self._seg_layer_hook_names: registry.remove_hook(hook_name, recurse=True) + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"] + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches def forward( self, @@ -208,11 +192,11 @@ def forward( if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred + return pred, {} @property def is_conditional(self) -> bool: - return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2 + return self._count_prepared == 1 or self._count_prepared == 3 @property def num_conditions(self) -> int: diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 7529114bfd6f..631f9a5f33b2 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -13,11 +13,14 @@ # limitations under the License. 
import math -from typing import Optional, Union, Tuple, List +from typing import Optional, List, TYPE_CHECKING import torch -from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..pipelines.modular_pipeline import BlockState class TangentialClassifierFreeGuidance(BaseGuidance): @@ -59,15 +62,13 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - return _default_prepare_inputs(denoiser, self.num_conditions, *args) - - def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - self._num_outputs_prepared += 1 - if self._num_outputs_prepared > self.num_conditions: - raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") - key = self._input_predictions[self._num_outputs_prepared - 1] - self._preds[key] = pred + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None @@ -80,11 +81,11 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - return pred + return pred, {} @property def is_conditional(self) -> bool: - return self._num_outputs_prepared == 0 + return self._num_outputs_prepared == 1 @property def num_conditions(self) -> int: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index b70a4b6c082a..c643ab399833 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2278,6 +2278,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: added_cond_kwargs=batch.added_cond_kwargs, return_dict=False, )[0] + pipeline.guider.cleanup_models(pipeline.unet) # Perform guidance data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) @@ -2618,29 +2619,20 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + pipeline.guider.set_input_fields( + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + add_time_ids=("add_time_ids", "negative_add_time_ids"), + pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), + ) + # (5) Denoise loop with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + guider_data = 
pipeline.guider.prepare_inputs(data) - ( - latents, - prompt_embeds, - add_time_ids, - pooled_prompt_embeds, - mask, - masked_image_latents, - ip_adapter_embeds, - ) = pipeline.guider.prepare_inputs( - pipeline.unet, - data.latents, - (data.prompt_embeds, data.negative_prompt_embeds), - (data.add_time_ids, data.negative_add_time_ids), - (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds), - data.mask, - data.masked_image_latents, - (data.ip_adapter_embeds, data.negative_ip_adapter_embeds), - ) + data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) if isinstance(data.controlnet_keep[i], list): data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] @@ -2649,67 +2641,63 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if isinstance(data.controlnet_cond_scale, list): data.controlnet_cond_scale = data.controlnet_cond_scale[0] data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - - for batch_index, ( - latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i - ) in enumerate(zip( - latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds - )): + + for batch in guider_data: pipeline.guider.prepare_models(pipeline.unet) - latents_i = pipeline.scheduler.scale_model_input(latents_i, t) - - # Prepare for inpainting - if data.num_channels_unet == 9: - latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) # Prepare additional conditionings - data.added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds_i, - "time_ids": add_time_ids_i, + batch.added_cond_kwargs = { + "text_embeds": batch.pooled_prompt_embeds, + "time_ids": batch.add_time_ids, } - if ip_adapter_embeds_i is not None: - data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i + if batch.ip_adapter_embeds is not None: + batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds # Prepare controlnet additional conditionings - data.controlnet_added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds_i, - "time_ids": add_time_ids_i, + batch.controlnet_added_cond_kwargs = { + "text_embeds": batch.pooled_prompt_embeds, + "time_ids": batch.add_time_ids, } + # Will always be run atleast once with every guider if pipeline.guider.is_conditional or not data.guess_mode: data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - latents_i, + data.scaled_latents, t, - encoder_hidden_states=prompt_embeds_i, + encoder_hidden_states=batch.prompt_embeds, controlnet_cond=data.control_image, conditioning_scale=data.cond_scale, guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, + added_cond_kwargs=batch.controlnet_added_cond_kwargs, return_dict=False, ) - elif pipeline.guider.is_unconditional and data.guess_mode: - data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] - data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) - + + batch.down_block_res_samples = data.down_block_res_samples + batch.mid_block_res_sample = data.mid_block_res_sample + + if pipeline.guider.is_unconditional and data.guess_mode: + batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] + batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) + + # Prepare for inpainting if data.num_channels_unet == 9: - latents_i = torch.cat([latents_i, mask_i, 
masked_image_latents_i], dim=1) + data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) - data.noise_pred = pipeline.unet( - latents_i, + batch.noise_pred = pipeline.unet( + data.scaled_latents, t, - encoder_hidden_states=prompt_embeds_i, + encoder_hidden_states=batch.prompt_embeds, timestep_cond=data.timestep_cond, cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, + added_cond_kwargs=batch.added_cond_kwargs, + down_block_additional_residuals=batch.down_block_res_samples, + mid_block_additional_residual=batch.mid_block_res_sample, return_dict=False, )[0] - data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) + pipeline.guider.cleanup_models(pipeline.unet) # Perform guidance - outputs, scheduler_step_kwargs = pipeline.guider.outputs - data.noise_pred = pipeline.guider(**outputs) + data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype @@ -3035,28 +3023,19 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) + pipeline.guider.set_input_fields( + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + add_time_ids=("add_time_ids", "negative_add_time_ids"), + pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), + ) + with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + guider_data = pipeline.guider.prepare_inputs(data) - ( - latents, - prompt_embeds, - add_time_ids, - pooled_prompt_embeds, - mask, - masked_image_latents, - ip_adapter_embeds, - ) = pipeline.guider.prepare_inputs( - pipeline.unet, - data.latents, - (data.prompt_embeds, data.negative_prompt_embeds), - (data.add_time_ids, data.negative_add_time_ids), - (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds), - data.mask, - data.masked_image_latents, - (data.ip_adapter_embeds, data.negative_ip_adapter_embeds), - ) + data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) if isinstance(data.controlnet_keep[i], list): data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] @@ -3065,69 +3044,64 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if isinstance(data.controlnet_cond_scale, list): data.controlnet_cond_scale = data.controlnet_cond_scale[0] data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - - for batch_index, ( - latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i - ) in enumerate(zip( - latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds - )): + + for batch in guider_data: pipeline.guider.prepare_models(pipeline.unet) - latents_i = pipeline.scheduler.scale_model_input(latents_i, t) - - # Prepare for inpainting - if data.num_channels_unet == 9: - latents_i = 
torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) # Prepare additional conditionings - data.added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds_i, - "time_ids": add_time_ids_i, + batch.added_cond_kwargs = { + "text_embeds": batch.pooled_prompt_embeds, + "time_ids": batch.add_time_ids, } - if ip_adapter_embeds_i is not None: - data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i + if batch.ip_adapter_embeds is not None: + batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds # Prepare controlnet additional conditionings - data.controlnet_added_cond_kwargs = { - "text_embeds": pooled_prompt_embeds_i, - "time_ids": add_time_ids_i, + batch.controlnet_added_cond_kwargs = { + "text_embeds": batch.pooled_prompt_embeds, + "time_ids": batch.add_time_ids, } - + + # Will always be run atleast once with every guider if pipeline.guider.is_conditional or not data.guess_mode: data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - latents_i, + data.scaled_latents, t, - encoder_hidden_states=prompt_embeds_i, + encoder_hidden_states=batch.prompt_embeds, controlnet_cond=data.control_image, control_type=data.control_type, control_type_idx=data.control_mode, conditioning_scale=data.cond_scale, guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, + added_cond_kwargs=batch.controlnet_added_cond_kwargs, return_dict=False, ) - elif pipeline.guider.is_unconditional and data.guess_mode: - data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] - data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) + + batch.down_block_res_samples = data.down_block_res_samples + batch.mid_block_res_sample = data.mid_block_res_sample + + if pipeline.guider.is_unconditional and data.guess_mode: + batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] + batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) if data.num_channels_unet == 9: - latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) - data.noise_pred = pipeline.unet( - latents_i, + batch.noise_pred = pipeline.unet( + data.scaled_latents, t, - encoder_hidden_states=prompt_embeds_i, + encoder_hidden_states=batch.prompt_embeds, timestep_cond=data.timestep_cond, cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, + added_cond_kwargs=batch.added_cond_kwargs, + down_block_additional_residuals=batch.down_block_res_samples, + mid_block_additional_residual=batch.mid_block_res_sample, return_dict=False, )[0] - data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) + pipeline.guider.cleanup_models(pipeline.unet) # Perform guidance - outputs, scheduler_step_kwargs = pipeline.guider.outputs - data.noise_pred = pipeline.guider(**outputs) + data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype From 5a4d2c72d526ff09b3c6380c3902438f4621dc7d Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 25 Apr 2025 20:58:30 +0200 Subject: [PATCH 21/24] rename enable/disable --- src/diffusers/guiders/guider_utils.py | 4 ++-- .../pipeline_stable_diffusion_xl_modular.py | 12 ++++++------ 2 files 
changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index f144a77aa11c..179e8acf2731 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -57,10 +57,10 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): "`_input_predictions` must be a list of required prediction names for the guidance technique." ) - def _force_disable(self): + def disable(self): self._enabled = False - def _force_enable(self): + def enable(self): self._enabled = True def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index c643ab399833..5cf4b937b04f 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2231,9 +2231,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider._force_disable() + pipeline.guider.disable() else: - pipeline.guider._force_enable() + pipeline.guider.enable() # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -2611,9 +2611,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (2) Prepare conditional inputs for unet using the guider data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider._force_disable() + pipeline.guider.disable() else: - pipeline.guider._force_enable() + pipeline.guider.enable() # (4) Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -3011,9 +3011,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider._force_disable() + pipeline.guider.disable() else: - pipeline.guider._force_enable() + pipeline.guider.enable() data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] From 53ebfa139729e8b5a9ec5510105905f8463303b5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 25 Apr 2025 21:00:58 +0200 Subject: [PATCH 22/24] remove cfg++ for now --- src/diffusers/__init__.py | 2 - src/diffusers/guiders/__init__.py | 1 - .../classifier_free_guidance_plus_plus.py | 111 ------------------ src/diffusers/schedulers/scheduling_ddim.py | 9 +- .../schedulers/scheduling_euler_discrete.py | 37 ------ 5 files changed, 1 insertion(+), 159 deletions(-) delete mode 100644 src/diffusers/guiders/classifier_free_guidance_plus_plus.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 424011961ab0..a4f55acf8b70 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -134,7 +134,6 @@ [ "AdaptiveProjectedGuidance", "AutoGuidance", - "CFGPlusPlusGuidance", "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", @@ -730,7 +729,6 @@ from .guiders import ( AdaptiveProjectedGuidance, AutoGuidance, - CFGPlusPlusGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 56e95c92b697..3c1ee293382d 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -20,7 +20,6 @@ if is_torch_available(): from .adaptive_projected_guidance import AdaptiveProjectedGuidance from .auto_guidance import AutoGuidance - from .classifier_free_guidance_plus_plus import CFGPlusPlusGuidance from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py deleted file mode 100644 index 1f44f883c248..000000000000 --- a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
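For reference, the CFG++ guider deleted below combines the conditional and unconditional predictions exactly like classifier-free guidance, but then hands the unconditional prediction back to the scheduler so that renoising uses it instead of the guided output. A minimal sketch of that update for a single Euler step, assuming epsilon prediction and the default no-churn step (sigma_hat == sigma); `cfgpp_euler_step` is an illustrative helper written for this note, not an API from the patch:

    import torch

    def cfgpp_euler_step(sample, eps_cond, eps_uncond, sigma, sigma_next, guidance_scale=0.7):
        # Guided noise prediction, diffusers-style CFG formulation (same combination as the removed forward()).
        eps_guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
        # Denoised estimate x0 still uses the guided prediction (epsilon parameterization assumed).
        denoised = sample - eps_guided * sigma
        # CFG++: renoise with the unconditional prediction rather than the guided one;
        # this is what the removed `_model_output_uncond` scheduler plumbing achieved.
        return denoised + eps_uncond * sigma_next

This is algebraically the same as the correction the deleted Euler code applied, prev_sample + (eps_uncond - eps_guided) * sigma_next, per the derivation comments removed further down.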
- -from typing import Dict, Optional, List, TYPE_CHECKING - -import torch - -from .guider_utils import BaseGuidance, rescale_noise_cfg - -if TYPE_CHECKING: - from ..pipelines.modular_pipeline import BlockState - - -class CFGPlusPlusGuidance(BaseGuidance): - """ - CFG++: https://huggingface.co/papers/2406.08070 - - Args: - guidance_scale (`float`, defaults to `0.7`): - The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text - prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and - deterioration of image quality. - guidance_rescale (`float`, defaults to `0.0`): - The rescale factor applied to the noise predictions. This is used to improve image quality and fix - overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://huggingface.co/papers/2305.08891). - use_original_formulation (`bool`, defaults to `False`): - Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, - we use the diffusers-native implementation that has been in the codebase for a long time. See - [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. - start (`float`, defaults to `0.0`): - The fraction of the total number of denoising steps after which guidance starts. - stop (`float`, defaults to `1.0`): - The fraction of the total number of denoising steps after which guidance stops. - """ - - _input_predictions = ["pred_cond", "pred_uncond"] - - def __init__( - self, - guidance_scale: float = 0.7, - guidance_rescale: float = 0.0, - use_original_formulation: bool = False, - start: float = 0.0, - stop: float = 1.0, - ): - super().__init__(start, stop) - - self.guidance_scale = guidance_scale - self.guidance_rescale = guidance_rescale - self.use_original_formulation = use_original_formulation - - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: - tuple_indices = [0] if self.num_conditions == 1 else [0, 1] - data_batches = [] - for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) - data_batches.append(data_batch) - return data_batches - - def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: - pred = None - - if not self._is_cfgpp_enabled(): - pred = pred_cond - else: - shift = pred_cond - pred_uncond - pred = pred_cond if self.use_original_formulation else pred_uncond - pred = pred + self.guidance_scale * shift - - if self.guidance_rescale > 0.0: - pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - - scheduler_kwargs = {} - if self._is_cfgpp_enabled(): - scheduler_kwargs = {"_use_cfgpp": True, "_model_output_uncond": pred_uncond} - return pred, scheduler_kwargs - - @property - def is_conditional(self) -> bool: - return self._count_prepared == 1 - - @property - def num_conditions(self) -> int: - num_conditions = 1 - if self._is_cfgpp_enabled(): - num_conditions += 1 - return num_conditions - - def _is_cfgpp_enabled(self) -> bool: - if not self._enabled: - return False - - is_within_range = True - if self._num_inference_steps is not None: - skip_start_step = int(self._start * self._num_inference_steps) - skip_stop_step = int(self._stop * self._num_inference_steps) - is_within_range = skip_start_step <= self._step < skip_stop_step - - return is_within_range diff --git a/src/diffusers/schedulers/scheduling_ddim.py 
b/src/diffusers/schedulers/scheduling_ddim.py index 2e74c9bbfccd..13c9b3b4a5e9 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -349,8 +349,6 @@ def step( generator=None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, - _model_output_uncond: Optional[torch.Tensor] = None, - _use_cfgpp: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -388,11 +386,6 @@ def step( raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - - if _use_cfgpp and self.config.prediction_type != "epsilon": - raise ValueError( - f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}." - ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding @@ -418,7 +411,7 @@ def step( # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - pred_epsilon = model_output if not _use_cfgpp else _model_output_uncond + pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 4c82ca7e389b..fbb33fe8b41c 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -585,7 +585,6 @@ def step( generator: Optional[torch.Generator] = None, return_dict: bool = True, _model_output_uncond: Optional[torch.Tensor] = None, - _use_cfgpp: bool = False, ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -629,11 +628,6 @@ def step( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) - - if _use_cfgpp and self.config.prediction_type != "epsilon": - raise ValueError( - f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}." 
- ) if self.step_index is None: self._init_step_index(timestep) @@ -675,37 +669,6 @@ def step( dt = self.sigmas[self.step_index + 1] - sigma_hat prev_sample = sample + derivative * dt - if _use_cfgpp: - prev_sample = prev_sample + (_model_output_uncond - model_output) * self.sigmas[self.step_index + 1] - - # denoised = sample - model_output * sigmas[i] - # d = (sample - denoised) / sigmas[i] - # new_sample = denoised + d * sigmas[i + 1] - - # new_sample = denoised + (sample - denoised) * sigmas[i + 1] / sigmas[i] - # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] - # new_sample = sample + model_output * (sigmas[i + 1] - sigmas[i]) - # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] --- (1) - - # CFG++ ===== - # denoised = sample - model_output * sigmas[i] - # uncond_denoised = sample - model_output_uncond * sigmas[i] - # d = (sample - uncond_denoised) / sigmas[i] - # new_sample = denoised + d * sigmas[i + 1] - - # new_sample = denoised + (sample - uncond_denoised) * sigmas[i + 1] / sigmas[i] - # new_sample = sample - model_output * sigmas[i] + model_output_uncond * sigmas[i + 1] --- (2) - - # To go from (1) to (2): - # new_sample_2 = new_sample_1 - model_output * sigmas[i + 1] + model_output_uncond * sigmas[i + 1] - # new_sample_2 = new_sample_1 + (model_output_uncond - model_output) * sigmas[i + 1] - # new_sample_2 = new_sample_1 + diff * sigmas[i + 1] - - # diff = model_output_uncond - model_output - # diff = model_output_uncond - (model_output_uncond + g * (model_output_cond - model_output_uncond)) - # diff = model_output_uncond - (g * model_output_cond + (1 - g) * model_output_uncond) - # diff = model_output_uncond - g * model_output_cond + (g - 1) * model_output_uncond - # diff = g * (model_output_uncond - model_output_cond) # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) From 6bc1dd57dd5059eb29e08705d231c5494a09e917 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 25 Apr 2025 21:10:40 +0200 Subject: [PATCH 23/24] rename do_classifier_free_guidance->prepare_unconditional_embeds --- .../pipeline_stable_diffusion_xl_modular.py | 35 +++++++++---------- .../schedulers/scheduling_euler_discrete.py | 2 +- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5cf4b937b04f..2493d5635552 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -233,10 +233,10 @@ def encode_image(self, components, image, device, num_images_per_prompt, output_ # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds ): image_embeds = [] - if do_classifier_free_guidance: + if prepare_unconditional_embeds: negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): @@ -256,11 +256,11 @@ def prepare_ip_adapter_image_embeds( ) image_embeds.append(single_image_embeds[None, :]) - if 
do_classifier_free_guidance: + if prepare_unconditional_embeds: negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: + if prepare_unconditional_embeds: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) @@ -268,7 +268,7 @@ def prepare_ip_adapter_image_embeds( ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: + if prepare_unconditional_embeds: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) @@ -281,7 +281,7 @@ def prepare_ip_adapter_image_embeds( def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) - data.do_classifier_free_guidance = pipeline.guider.num_conditions > 1 + data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1 data.device = pipeline._execution_device data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( @@ -290,9 +290,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ip_adapter_image_embeds=None, device=data.device, num_images_per_prompt=1, - do_classifier_free_guidance=data.do_classifier_free_guidance, + prepare_unconditional_embeds=data.prepare_unconditional_embeds, ) - if data.do_classifier_free_guidance: + if data.prepare_unconditional_embeds: data.negative_ip_adapter_embeds = [] for i, image_embeds in enumerate(data.ip_adapter_embeds): negative_image_embeds, image_embeds = image_embeds.chunk(2) @@ -355,7 +355,6 @@ def check_inputs(self, pipeline, data): elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with self -> components def encode_prompt( self, components, @@ -363,7 +362,7 @@ def encode_prompt( prompt_2: Optional[str] = None, device: Optional[torch.device] = None, num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, + prepare_unconditional_embeds: bool = True, negative_prompt: Optional[str] = None, negative_prompt_2: Optional[str] = None, prompt_embeds: Optional[torch.Tensor] = None, @@ -386,8 +385,8 @@ def encode_prompt( torch device num_images_per_prompt (`int`): number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is @@ -495,10 +494,10 @@ def encode_prompt( # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: + elif prepare_unconditional_embeds and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt @@ -559,7 +558,7 @@ def encode_prompt( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - if do_classifier_free_guidance: + if prepare_unconditional_embeds: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] @@ -574,7 +573,7 @@ def encode_prompt( pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) - if do_classifier_free_guidance: + if prepare_unconditional_embeds: negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) @@ -598,7 +597,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) self.check_inputs(pipeline, data) - data.do_classifier_free_guidance = pipeline.guider.num_conditions > 1 + data.prepare_unconditional_embeds = pipeline.guider.num_conditions > 1 data.device = pipeline._execution_device # Encode input prompt @@ -616,7 +615,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.prompt_2, data.device, 1, - data.do_classifier_free_guidance, + data.prepare_unconditional_embeds, data.negative_prompt, data.negative_prompt_2, prompt_embeds=None, diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index fbb33fe8b41c..9c8a8afaebdc 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -669,7 +669,7 @@ def step( dt = self.sigmas[self.step_index + 1] - sigma_hat prev_sample = sample + derivative * dt - + # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) From 704aef49ecbc71a0e095a880ce6469aca70e25e5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 26 Apr 2025 00:03:25 +0200 Subject: [PATCH 24/24] remove unused --- src/diffusers/guiders/guider_utils.py | 39 ------------------- .../schedulers/scheduling_euler_discrete.py | 1 - 2 files changed, 40 deletions(-) diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 179e8acf2731..7d005442e89c 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: - from ..models.attention_processor import AttentionProcessor from ..pipelines.modular_pipeline import BlockState @@ -214,41 +213,3 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): # mix with the original results from guidance by factor 
guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg - - -def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - """ - Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly - prepared based on required number of conditions. This function is used in the `prepare_inputs` method of the - `BaseGuidance` class. - - Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements: - - The first element is the conditional input. - - The second element is the unconditional input or None. - - If only the conditional input is provided, it will be repeated for all batches. - - If both conditional and unconditional inputs are provided, they are alternated as batches of data. - """ - list_of_inputs = [] - for arg in args: - if arg is None or isinstance(arg, torch.Tensor): - list_of_inputs.append([arg] * num_conditions) - elif isinstance(arg, (tuple, list)): - if len(arg) != 2: - raise ValueError( - f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " - f"with the first element being the conditional input and the second element being the unconditional input or None." - ) - if arg[1] is None: - # Only conditioning inputs for all batches - list_of_inputs.append([arg[0]] * num_conditions) - else: - # Alternating conditional and unconditional inputs as batches - inputs = [arg[i % 2] for i in range(num_conditions)] - list_of_inputs.append(inputs) - else: - raise ValueError( - f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." - ) - return tuple(list_of_inputs) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 9c8a8afaebdc..56757f3ca197 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -584,7 +584,6 @@ def step( s_noise: float = 1.0, generator: Optional[torch.Generator] = None, return_dict: bool = True, - _model_output_uncond: Optional[torch.Tensor] = None, ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion