From 0c4c1a843089a6411233a69b7e27473d78e869c8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 10:04:13 +0200 Subject: [PATCH 01/39] cfg; slg; pag; sdxl without controlnet --- src/diffusers/__init__.py | 15 + src/diffusers/guiders/__init__.py | 24 ++ .../guiders/classifier_free_guidance.py | 111 +++++++ src/diffusers/guiders/guider_utils.py | 148 ++++++++++ src/diffusers/guiders/skip_layer_guidance.py | 218 ++++++++++++++ src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/_common.py | 32 +++ src/diffusers/hooks/_helpers.py | 271 ++++++++++++++++++ src/diffusers/hooks/layer_skip.py | 194 +++++++++++++ .../pipeline_stable_diffusion_xl_modular.py | 147 +++++----- src/diffusers/utils/torch_utils.py | 5 + 11 files changed, 1093 insertions(+), 73 deletions(-) create mode 100644 src/diffusers/guiders/__init__.py create mode 100644 src/diffusers/guiders/classifier_free_guidance.py create mode 100644 src/diffusers/guiders/guider_utils.py create mode 100644 src/diffusers/guiders/skip_layer_guidance.py create mode 100644 src/diffusers/hooks/_common.py create mode 100644 src/diffusers/hooks/_helpers.py create mode 100644 src/diffusers/hooks/layer_skip.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 440c67da629d..d8e274cb93b6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -33,6 +33,7 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "guiders": [], "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], @@ -129,12 +130,20 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["guiders"].extend( + [ + "ClassifierFreeGuidance", + "SkipLayerGuidance", + ] + ) _import_structure["hooks"].extend( [ "FasterCacheConfig", "HookRegistry", "PyramidAttentionBroadcastConfig", + "LayerSkipConfig", "apply_faster_cache", + "apply_layer_skip", "apply_pyramid_attention_broadcast", ] ) @@ -711,10 +720,16 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .guiders import ( + ClassifierFreeGuidance, + SkipLayerGuidance, + ) from .hooks import ( FasterCacheConfig, HookRegistry, + LayerSkipConfig, PyramidAttentionBroadcastConfig, + apply_layer_skip, apply_faster_cache, apply_pyramid_attention_broadcast, ) diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py new file mode 100644 index 000000000000..adef65277b6d --- /dev/null +++ b/src/diffusers/guiders/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from ..utils import is_torch_available + + +if is_torch_available(): + from .classifier_free_guidance import ClassifierFreeGuidance + from .skip_layer_guidance import SkipLayerGuidance + + GuiderType = Union[ClassifierFreeGuidance, SkipLayerGuidance] diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py new file mode 100644 index 000000000000..4048d70484c5 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -0,0 +1,111 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union, Tuple, List + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class ClassifierFreeGuidance(BaseGuidance): + """ + Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by + jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during + inference. This allows the model to tradeoff between generation quality and sample diversity. + The original paper proposes scaling and shifting the conditional distribution based on the difference between + conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen + paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in + theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + The intution behind the original formulation can be thought of as moving the conditional distribution estimates + further away from the unconditional distribution estimates, while the diffusers-native implementation can be + thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of + the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the + paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0 + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + return is_within_range and not is_close diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py new file mode 100644 index 000000000000..ecde7334b2b9 --- /dev/null +++ b/src/diffusers/guiders/guider_utils.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union + +import torch + +from ..utils import get_logger + + +if TYPE_CHECKING: + from ..models.attention_processor import AttentionProcessor + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class BaseGuidance: + r"""Base class providing the skeleton for implementing guidance techniques.""" + + _input_predictions = None + + def __init__(self, start: float = 0.0, stop: float = 1.0): + self._start = start + self._stop = stop + self._step: int = None + self._num_inference_steps: int = None + self._timestep: torch.LongTensor = None + self._preds: Dict[str, torch.Tensor] = {} + self._num_outputs_prepared: int = 0 + + if not (0.0 <= start < 1.0): + raise ValueError( + f"Expected `start` to be between 0.0 and 1.0, but got {start}." + ) + if not (start <= stop <= 1.0): + raise ValueError( + f"Expected `stop` to be between {start} and 1.0, but got {stop}." + ) + + if self._input_predictions is None or not isinstance(self._input_predictions, list): + raise ValueError( + "`_input_predictions` must be a list of required prediction names for the guidance technique." + ) + + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + self._step = step + self._num_inference_steps = num_inference_steps + self._timestep = timestep + self._preds = {} + self._num_outputs_prepared = 0 + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + raise NotImplementedError("GuidanceMixin::prepare_inputs must be implemented in subclasses.") + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + raise NotImplementedError("GuidanceMixin::prepare_outputs must be implemented in subclasses.") + + def __call__(self, **kwargs) -> Any: + if len(kwargs) != self.num_conditions: + raise ValueError( + f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments." + ) + return self.forward(**kwargs) + + def forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.") + + @property + def num_conditions(self) -> int: + raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.") + + @property + def outputs(self) -> Dict[str, torch.Tensor]: + return self._preds + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + """ + Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly + prepared based on required number of conditions. This function is used in the `prepare_inputs` method of the + `GuidanceMixin` class. + + Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements: + - The first element is the conditional input. + - The second element is the unconditional input or None. + + If only the conditional input is provided, it will be repeated for all batches. + + If both conditional and unconditional inputs are provided, they are alternated as batches of data. + """ + list_of_inputs = [] + for arg in args: + if arg is None or isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + # Alternating conditional and unconditional inputs as batches + inputs = [arg[i % 2] for i in range(num_conditions)] + list_of_inputs.append(inputs) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py new file mode 100644 index 000000000000..e20a700fee6a --- /dev/null +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -0,0 +1,218 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +class SkipLayerGuidance(BaseGuidance): + """ + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG): + https://huggingface.co/papers/2411.18664 + SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by + skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional + batch of data, apart from the conditional and unconditional batches already used in CFG + ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions + based on the difference between conditional without skipping and conditional with skipping predictions. + The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from + worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse + version of the model for the conditional prediction). + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving + generation quality in video diffusion models. + Additional reading: + - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are + defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + skip_layer_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + skip_layer_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + + def __init__( + self, + guidance_scale: float = 7.5, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.01, + stop: float = 0.2, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if skip_layer_guidance_layers is None and skip_layer_config is None: + raise ValueError( + "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." + ) + if skip_layer_guidance_layers is not None and skip_layer_config is not None: + raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.") + + if skip_layer_guidance_layers is not None: + if isinstance(skip_layer_guidance_layers, int): + skip_layer_guidance_layers = [skip_layer_guidance_layers] + if not isinstance(skip_layer_guidance_layers, list): + raise ValueError( + f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}." + ) + skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers] + + if isinstance(skip_layer_config, LayerSkipConfig): + skip_layer_config = [skip_layer_config] + + if not isinstance(skip_layer_config, list): + raise ValueError( + f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}." + ) + + self.skip_layer_config = skip_layer_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + if self._num_outputs_prepared == 0 and self._is_slg_enabled(): + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if arg is None or isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + list_of_inputs.append([arg[0], arg[1], arg[0]]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + if not self._is_cfg_enabled() and self._is_slg_enabled(): + # If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip + # to avoid writing into pred_uncond which is not used + if self._num_outputs_prepared == 2: + key = "pred_cond_skip" + self._preds[key] = pred + + if self._num_outputs_prepared == self.num_conditions: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_slg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + elif not self._is_slg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_skip = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_slg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + return is_within_range and not is_close + + def _is_slg_enabled(self) -> bool: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + return is_within_range and not is_zero diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 764ceb25b465..142ff860371c 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -5,5 +5,6 @@ from .faster_cache import FasterCacheConfig, apply_faster_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook + from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py new file mode 100644 index 000000000000..6ea83dcbf6a7 --- /dev/null +++ b/src/diffusers/hooks/_common.py @@ -0,0 +1,32 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..models.attention import FeedForward, LuminaFeedForward +from ..models.attention_processor import Attention, MochiAttention + + +_ATTENTION_CLASSES = (Attention, MochiAttention) +_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) + +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") + +_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple( + { + *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + } +) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py new file mode 100644 index 000000000000..9043ffc41838 --- /dev/null +++ b/src/diffusers/hooks/_helpers.py @@ -0,0 +1,271 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Type + +from ..models.attention import BasicTransformerBlock +from ..models.attention_processor import AttnProcessor2_0 +from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock +from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock +from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ..models.transformers.transformer_hunyuan_video import ( + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenReplaceSingleTransformerBlock, + HunyuanVideoTokenReplaceTransformerBlock, + HunyuanVideoTransformerBlock, +) +from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock +from ..models.transformers.transformer_mochi import MochiTransformerBlock +from ..models.transformers.transformer_wan import WanTransformerBlock + + +@dataclass +class AttentionProcessorMetadata: + skip_processor_output_fn: Callable[[Any], Any] + + +@dataclass +class TransformerBlockMetadata: + skip_block_output_fn: Callable[[Any], Any] + return_hidden_states_index: int = None + return_encoder_hidden_states_index: int = None + + +class AttentionProcessorRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> AttentionProcessorMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +class TransformerBlockRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +def _register_attention_processors_metadata(): + # AttnProcessor2_0 + AttentionProcessorRegistry.register( + model_class=AttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0, + ), + ) + + # CogView4AttnProcessor + AttentionProcessorRegistry.register( + model_class=CogView4AttnProcessor, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor, + ), + ) + + +def _register_transformer_blocks_metadata(): + # BasicTransformerBlock + TransformerBlockRegistry.register( + model_class=BasicTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # CogVideoX + TransformerBlockRegistry.register( + model_class=CogVideoXBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # CogView4 + TransformerBlockRegistry.register( + model_class=CogView4TransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Flux + TransformerBlockRegistry.register( + model_class=FluxTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + TransformerBlockRegistry.register( + model_class=FluxSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + # HunyuanVideo + TransformerBlockRegistry.register( + model_class=HunyuanVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # LTXVideo + TransformerBlockRegistry.register( + model_class=LTXVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # Mochi + TransformerBlockRegistry.register( + model_class=MochiTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Wan + TransformerBlockRegistry.register( + model_class=WanTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + +# fmt: off +def _skip_attention___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states +_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return encoder_hidden_states, hidden_states + + +_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +# fmt: on + + +_register_attention_processors_metadata() +_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py new file mode 100644 index 000000000000..b8471523902a --- /dev/null +++ b/src/diffusers/hooks/layer_skip.py @@ -0,0 +1,194 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Callable, List, Optional + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES +from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_LAYER_SKIP_HOOK = "layer_skip_hook" + + +@dataclass +class LayerSkipConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. + "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + skip_attention (`bool`, defaults to `True`): + Whether to skip attention blocks. + skip_ff (`bool`, defaults to `True`): + Whether to skip feed-forward blocks. + skip_attention_scores (`bool`, defaults to `False`): + Whether to skip attention score computation in the attention blocks. This is equivalent to using + `value` projections as the output of scaled dot product attention. + """ + + indices: List[int] + fqn: str = "auto" + skip_attention: bool = True + skip_attention_scores: bool = False + skip_ff: bool = True + + +class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): + def __init__(self) -> None: + super().__init__() + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.scaled_dot_product_attention: + value = kwargs.get("value", None) + if value is None: + value = args[2] + return value + return func(*args, **kwargs) + + +class AttentionProcessorSkipHook(ModelHook): + def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False): + self.skip_processor_output_fn = skip_processor_output_fn + self.skip_attention_scores = skip_attention_scores + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.skip_attention_scores: + with AttentionScoreSkipFunctionMode(): + return self.fn_ref.original_forward(*args, **kwargs) + else: + return self.skip_processor_output_fn(module, *args, **kwargs) + + +class FeedForwardSkipHook(ModelHook): + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + output = kwargs.get("hidden_states", None) + if output is None: + output = kwargs.get("x", None) + if output is None and len(args) > 0: + output = args[0] + return output + + +class TransformerBlockSkipHook(ModelHook): + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + return self._metadata.skip_block_output_fn(module, *args, **kwargs) + + +def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: + r""" + Apply layer skipping to internal layers of a transformer. + Args: + module (`torch.nn.Module`): + The transformer model to which the layer skip hook should be applied. + config (`LayerSkipConfig`): + The configuration for the layer skip hook. + Example: + ```python + >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig + >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks") + >>> apply_layer_skip_hook(transformer, config) + ``` + """ + _apply_layer_skip_hook(module, config) + + +def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None: + name = name or _LAYER_SKIP_HOOK + + if config.skip_attention and config.skip_attention_scores: + raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.") + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList): + raise ValueError( + f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify " + f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks." + ) + if len(config.indices) == 0: + raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.") + + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + blocks_found = True + if config.skip_attention and config.skip_ff: + logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = TransformerBlockSkipHook() + registry.register_hook(hook, name) + elif config.skip_attention or config.skip_attention_scores: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: + logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'") + output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores) + registry.register_hook(hook, name) + elif config.skip_ff: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _FEEDFORWARD_CLASSES): + logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = FeedForwardSkipHook() + registry.register_hook(hook, name) + else: + raise ValueError( + "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True." + ) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) + + +def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: + for submodule_name, submodule in module.named_modules(): + if submodule_name == fqn: + return submodule + return None diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 8e7109308962..5f125605a271 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -19,7 +19,6 @@ import torch from collections import OrderedDict -from ...guider import CFGGuider from ...image_processor import VaeImageProcessor, PipelineImageInput from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel @@ -58,7 +57,7 @@ ) from ...schedulers import KarrasDiffusionSchedulers -from ...guider import Guiders, CFGGuider +from ...guiders import GuiderType, ClassifierFreeGuidance import numpy as np @@ -2068,7 +2067,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ] @@ -2082,12 +2081,9 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: return [ - InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), - InputParam("guider_kwargs"), InputParam("num_images_per_prompt", default=1), ] @@ -2239,77 +2235,83 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - # Prepare conditional inputs using the guider - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - # expand the latents if we are doing classifier free guidance - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - - # inpainting - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - # predict the noise residual - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance( - data.noise_pred, - timestep=t, - latents=data.latents, + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + + ( + latents, + prompt_embeds, + add_time_ids, + pooled_prompt_embeds, + mask, + masked_image_latents, + ip_adapter_embeds, + ) = pipeline.guider.prepare_inputs( + pipeline.unet, + data.latents, + (data.prompt_embeds, data.negative_prompt_embeds), + (data.add_time_ids, data.negative_add_time_ids), + (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds), + data.mask, + data.masked_image_latents, + (data.ip_adapter_embeds, data.negative_ip_adapter_embeds), ) - # compute the previous noisy sample x_t -> x_t-1 + + for batch_index, ( + latents_i, + prompt_embeds_i, + add_time_ids_i, + pooled_prompt_embeds_i, + mask_i, + masked_image_latents_i, + ip_adapter_embeds_i, + ) in enumerate(zip( + latents, + prompt_embeds, + add_time_ids, + pooled_prompt_embeds, + mask, + masked_image_latents, + ip_adapter_embeds + )): + latents_i = pipeline.scheduler.scale_model_input(latents_i, t) + + # Prepare for inpainting + if data.num_channels_unet == 9: + latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + + data.added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds_i, + "time_ids": add_time_ids_i, + } + if ip_adapter_embeds_i is not None: + data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i + + # predict the noise residual + data.noise_pred = pipeline.unet( + latents_i, + t, + encoder_hidden_states=prompt_embeds_i, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=data.added_cond_kwargs, + return_dict=False, + )[0] + data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) + + # Perform guidance + outputs = pipeline.guider.outputs + data.noise_pred = pipeline.guider(**outputs) + + # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -2328,7 +2330,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - pipeline.guider.reset_guider(pipeline) self.add_block_state(state, data) return pipeline, state @@ -2341,12 +2342,12 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), - ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()), ] @property @@ -2792,8 +2793,8 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetUnionModel), ComponentSpec("scheduler", KarrasDiffusionSchedulers), - ComponentSpec("guider", CFGGuider, obj=CFGGuider()), - ComponentSpec("controlnet_guider", CFGGuider, obj=CFGGuider()), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), + ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), ] diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 3c8911773e39..06f9981f0138 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) +def unwrap_module(module): + """Unwraps a module if it was compiled with torch.compile()""" + return module._orig_mod if is_compiled_module(module) else module + + def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). From 9da8a9d1d557a290d691ee2a9eaebd107463130f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 14:04:04 +0200 Subject: [PATCH 02/39] support sdxl controlnet --- .../guiders/classifier_free_guidance.py | 6 + src/diffusers/guiders/guider_utils.py | 25 ++- src/diffusers/guiders/skip_layer_guidance.py | 8 + .../pipeline_stable_diffusion_xl_modular.py | 200 ++++++++---------- 4 files changed, 119 insertions(+), 120 deletions(-) diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 4048d70484c5..b3508307d478 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -92,6 +92,10 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = return pred + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + @property def num_conditions(self) -> int: num_conditions = 1 @@ -100,6 +104,8 @@ def num_conditions(self) -> int: return num_conditions def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index ecde7334b2b9..690afae89178 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -39,6 +39,7 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): self._timestep: torch.LongTensor = None self._preds: Dict[str, torch.Tensor] = {} self._num_outputs_prepared: int = 0 + self._enabled = True if not (0.0 <= start < 1.0): raise ValueError( @@ -54,6 +55,12 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): "`_input_predictions` must be a list of required prediction names for the guidance technique." ) + def force_disable(self): + self._enabled = False + + def force_enable(self): + self._enabled = True + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: self._step = step self._num_inference_steps = num_inference_steps @@ -62,10 +69,10 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen self._num_outputs_prepared = 0 def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - raise NotImplementedError("GuidanceMixin::prepare_inputs must be implemented in subclasses.") + raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - raise NotImplementedError("GuidanceMixin::prepare_outputs must be implemented in subclasses.") + raise NotImplementedError("BaseGuidance::prepare_outputs must be implemented in subclasses.") def __call__(self, **kwargs) -> Any: if len(kwargs) != self.num_conditions: @@ -75,11 +82,19 @@ def __call__(self, **kwargs) -> Any: return self.forward(**kwargs) def forward(self, *args, **kwargs) -> Any: - raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.") + raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") + @property + def is_conditional(self) -> bool: + raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") + + @property + def is_unconditional(self) -> bool: + return not self.is_conditional + @property def num_conditions(self) -> int: - raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.") + raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") @property def outputs(self) -> Dict[str, torch.Tensor]: @@ -114,7 +129,7 @@ def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *arg """ Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly prepared based on required number of conditions. This function is used in the `prepare_inputs` method of the - `GuidanceMixin` class. + `BaseGuidance` class. Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements: - The first element is the conditional input. diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index e20a700fee6a..92ae7f8518d8 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -189,6 +189,10 @@ def forward( pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2 @property def num_conditions(self) -> int: @@ -200,6 +204,8 @@ def num_conditions(self) -> int: return num_conditions def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step <= self._step < skip_stop_step @@ -211,6 +217,8 @@ def _is_cfg_enabled(self) -> bool: return is_within_range and not is_close def _is_slg_enabled(self) -> bool: + if not self._enabled: + return False skip_start_step = int(self._start * self._num_inference_steps) skip_stop_step = int(self._stop * self._num_inference_steps) is_within_range = skip_start_step < self._step < skip_stop_step diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5f125605a271..5df57c6c1605 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2263,21 +2263,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) for batch_index, ( - latents_i, - prompt_embeds_i, - add_time_ids_i, - pooled_prompt_embeds_i, - mask_i, - masked_image_latents_i, - ip_adapter_embeds_i, + latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i, ) in enumerate(zip( - latents, - prompt_embeds, - add_time_ids, - pooled_prompt_embeds, - mask, - masked_image_latents, - ip_adapter_embeds + latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): latents_i = pipeline.scheduler.scale_model_input(latents_i, t) @@ -2285,6 +2273,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if data.num_channels_unet == 9: latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + # Prepare additional conditionings data.added_cond_kwargs = { "text_embeds": pooled_prompt_embeds_i, "time_ids": add_time_ids_i, @@ -2292,7 +2281,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if ip_adapter_embeds_i is not None: data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i - # predict the noise residual + # Predict the noise residual data.noise_pred = pipeline.unet( latents_i, t, @@ -2347,7 +2336,6 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), - ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()), ] @property @@ -2363,8 +2351,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), @@ -2515,8 +2501,8 @@ def prepare_control_image( image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] - if image_batch_size == 1: repeat_by = batch_size else: @@ -2524,9 +2510,7 @@ def prepare_control_image( repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) - image = image.to(device=device, dtype=dtype) - return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components @@ -2557,9 +2541,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] data.height = data.height * pipeline.vae_scale_factor data.width = data.width * pipeline.vae_scale_factor @@ -2642,59 +2624,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # (2) Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds + if data.disable_guidance: + pipeline.guider.force_disable() # (3) Prepare conditional inputs for controlnet using the guider data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False - data.controlnet_guider_kwargs = data.guider_kwargs or {} - data.controlnet_guider_kwargs = { - **data.controlnet_guider_kwargs, - "disable_guidance": data.controlnet_disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) - data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - data.controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - } - data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -2703,11 +2638,26 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (5) Denoise loop with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - # prepare latents for unet using the guider - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - # prepare latents for controlnet using the guider - data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) + ( + latents, + prompt_embeds, + add_time_ids, + pooled_prompt_embeds, + mask, + masked_image_latents, + ip_adapter_embeds, + ) = pipeline.guider.prepare_inputs( + pipeline.unet, + data.latents, + (data.prompt_embeds, data.negative_prompt_embeds), + (data.add_time_ids, data.negative_add_time_ids), + (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds), + data.mask, + data.masked_image_latents, + (data.ip_adapter_embeds, data.negative_ip_adapter_embeds), + ) if isinstance(data.controlnet_keep[i], list): data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] @@ -2717,51 +2667,74 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.controlnet_cond_scale = data.controlnet_cond_scale[0] data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(data.control_model_input, t), - t, - encoder_hidden_states=data.controlnet_prompt_embeds, - controlnet_cond=data.control_image, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, - return_dict=False, - ) + for batch_index, ( + latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i + ) in enumerate(zip( + latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds + )): + latents_i = pipeline.scheduler.scale_model_input(latents_i, t) + + # Prepare for inpainting + if data.num_channels_unet == 9: + latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + + # Prepare additional conditionings + data.added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds_i, + "time_ids": add_time_ids_i, + } + if ip_adapter_embeds_i is not None: + data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i + + # Prepare controlnet additional conditionings + data.controlnet_added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds_i, + "time_ids": add_time_ids_i, + } - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - data.down_block_res_samples = pipeline.guider.prepare_input( - data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] - ) - data.mid_block_res_sample = pipeline.guider.prepare_input( - data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) - ) + if pipeline.guider.is_conditional or not data.guess_mode: + data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( + latents_i, + t, + encoder_hidden_states=prompt_embeds_i, + controlnet_cond=data.control_image, + conditioning_scale=data.cond_scale, + guess_mode=data.guess_mode, + added_cond_kwargs=data.controlnet_added_cond_kwargs, + return_dict=False, + ) + elif pipeline.guider.is_unconditional and data.guess_mode: + data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] + data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) + if data.num_channels_unet == 9: + latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) - # compute the previous noisy sample x_t -> x_t-1 + data.noise_pred = pipeline.unet( + latents_i, + t, + encoder_hidden_states=prompt_embeds_i, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=data.added_cond_kwargs, + down_block_additional_residuals=data.down_block_res_samples, + mid_block_additional_residual=data.mid_block_res_sample, + return_dict=False, + )[0] + data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) + + # Perform guidance + outputs = pipeline.guider.outputs + data.noise_pred = pipeline.guider(**outputs) + + # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 data.latents = data.latents.to(data.latents_dtype) - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: data.init_latents_proper = data.image_latents @@ -2775,9 +2748,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - pipeline.controlnet_guider.reset_guider(pipeline) self.add_block_state(state, data) From b81bd78bf9ad422720b52ec86907ea0a945181e3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 15:39:16 +0200 Subject: [PATCH 03/39] support controlnet union --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 1 + .../guiders/adaptive_projected_guidance.py | 174 +++++++++++++ .../guiders/classifier_free_guidance.py | 11 +- src/diffusers/guiders/skip_layer_guidance.py | 22 +- .../pipeline_stable_diffusion_xl_modular.py | 246 ++++++++---------- 6 files changed, 312 insertions(+), 144 deletions(-) create mode 100644 src/diffusers/guiders/adaptive_projected_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d8e274cb93b6..d3a61df3baff 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -132,6 +132,7 @@ else: _import_structure["guiders"].extend( [ + "AdaptiveProjectedGuidance", "ClassifierFreeGuidance", "SkipLayerGuidance", ] @@ -721,6 +722,7 @@ from .utils.dummy_pt_objects import * # noqa F403 else: from .guiders import ( + AdaptiveProjectedGuidance, ClassifierFreeGuidance, SkipLayerGuidance, ) diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index adef65277b6d..e3c6494de090 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -18,6 +18,7 @@ if is_torch_available(): + from .adaptive_projected_guidance import AdaptiveProjectedGuidance from .classifier_free_guidance import ClassifierFreeGuidance from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py new file mode 100644 index 000000000000..ab5175745b07 --- /dev/null +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -0,0 +1,174 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union, Tuple, List + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class AdaptiveProjectedGuidance(BaseGuidance): + """ + Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + adaptive_projected_guidance_momentum (`float`, defaults to `None`): + The momentum parameter for the adaptive projected guidance. Disabled if set to `None`. + adaptive_projected_guidance_rescale (`float`, defaults to `15.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + adaptive_projected_guidance_momentum: Optional[float] = None, + adaptive_projected_guidance_rescale: float = 15.0, + eta: float = 1.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale + self.eta = eta + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.eta, + self.adaptive_projected_guidance_rescale, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + +def normalized_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer: Optional[MomentumBuffer] = None, + eta: float = 1.0, + norm_threshold: float = 0.0, + use_original_formulation: bool = False, +): + diff = pred_cond - pred_uncond + dim = [-i for i in range(1, len(diff.shape))] + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=dim) + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) + normalized_update = diff_orthogonal + eta * diff_parallel + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + (guidance_scale - 1) * normalized_update + return pred diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index b3508307d478..3deacdfb2863 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -106,12 +106,17 @@ def num_conditions(self) -> int: def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - skip_start_step = int(self._start * self._num_inference_steps) - skip_stop_step = int(self._stop * self._num_inference_steps) - is_within_range = skip_start_step <= self._step < skip_stop_step + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) + return is_within_range and not is_close diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 92ae7f8518d8..4abb93272e9a 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -206,21 +206,31 @@ def num_conditions(self) -> int: def _is_cfg_enabled(self) -> bool: if not self._enabled: return False - skip_start_step = int(self._start * self._num_inference_steps) - skip_stop_step = int(self._stop * self._num_inference_steps) - is_within_range = skip_start_step <= self._step < skip_stop_step + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + is_close = False if self.use_original_formulation: is_close = math.isclose(self.guidance_scale, 0.0) else: is_close = math.isclose(self.guidance_scale, 1.0) + return is_within_range and not is_close def _is_slg_enabled(self) -> bool: if not self._enabled: return False - skip_start_step = int(self._start * self._num_inference_steps) - skip_stop_step = int(self._stop * self._num_inference_steps) - is_within_range = skip_start_step < self._step < skip_stop_step + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + return is_within_range and not is_zero diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5df57c6c1605..24d7e333c4fa 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -194,11 +194,7 @@ def inputs(self) -> List[InputParam]: PipelineImageInput, required=True, description="The image(s) to be used as ip adapter" - ), - InputParam( - "guidance_scale", - default=5.0, - ), + ) ] @@ -236,11 +232,10 @@ def encode_image(self, components, image, device, num_images_per_prompt, output_ # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt ): image_embeds = [] - if do_classifier_free_guidance: - negative_image_embeds = [] + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -259,21 +254,18 @@ def prepare_ip_adapter_image_embeds( ) image_embeds.append(single_image_embeds[None, :]) - if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None, :]) + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) @@ -323,6 +315,7 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), ] @property @@ -337,7 +330,6 @@ def inputs(self) -> List[InputParam]: InputParam("negative_prompt"), InputParam("negative_prompt_2"), InputParam("cross_attention_kwargs"), - InputParam("guidance_scale",default=5.0), InputParam("clip_skip"), ] @@ -601,10 +593,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) self.check_inputs(pipeline, data) - data.do_classifier_free_guidance = data.guidance_scale > 1.0 + data.do_classifier_free_guidance = pipeline.guider.num_conditions > 1 data.device = pipeline._execution_device - # Encode input prompt data.text_encoder_lora_scale = ( data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None @@ -1750,7 +1741,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", required=True), InputParam("aesthetic_score", default=6.0), InputParam("negative_aesthetic_score", default=2.0), ] @@ -1897,7 +1887,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin and pipeline.unet is not None and pipeline.unet.config.time_cond_proj_dim is not None ): - data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) data.timestep_cond = self.get_guidance_scale_embedding( data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim ).to(device=data.device, dtype=data.latents.dtype) @@ -1925,7 +1916,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), ] @property @@ -2051,7 +2041,8 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin and pipeline.unet is not None and pipeline.unet.config.time_cond_proj_dim is not None ): - data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + data.guidance_scale_tensor = torch.tensor(pipeline.guider.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) data.timestep_cond = self.get_guidance_scale_embedding( data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim ).to(device=data.device, dtype=data.latents.dtype) @@ -2234,6 +2225,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False + if data.disable_guidance: + pipeline.guider.force_disable() + else: + pipeline.guider.force_enable() # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -2354,7 +2349,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), - InputParam("guider_kwargs"), ] @property @@ -2627,9 +2621,8 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: pipeline.guider.force_disable() - - # (3) Prepare conditional inputs for controlnet using the guider - data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False + else: + pipeline.guider.force_enable() # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -2764,7 +2757,6 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("controlnet", ControlNetUnionModel), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("controlnet_guider", GuiderType, obj=ClassifierFreeGuidance()), ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), ] @@ -2781,12 +2773,9 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=5.0), - InputParam("guidance_rescale", default=0.0), InputParam("cross_attention_kwargs"), InputParam("generator"), InputParam("eta", default=0.0), - InputParam("guider_kwargs") ] @property @@ -3029,7 +3018,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: crops_coords=data.crops_coords, ) data.height, data.width = data.control_image[idx].shape[-2:] - # (1.6) # controlnet_keep @@ -3043,80 +3031,48 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (2) Prepare conditional inputs for unet using the guider # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds + if data.disable_guidance: + pipeline.guider.force_disable() + else: + pipeline.guider.force_enable() # (3) Prepare conditional inputs for controlnet using the guider - data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False - data.controlnet_guider_kwargs = data.guider_kwargs or {} - data.controlnet_guider_kwargs = { - **data.controlnet_guider_kwargs, - "disable_guidance": data.controlnet_disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) - data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - data.controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - } - for idx, _ in enumerate(data.control_image): - data.control_image[idx] = pipeline.controlnet_guider.prepare_input(data.control_image[idx], data.control_image[idx]) + # data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) + # data.controlnet_added_cond_kwargs = { + # "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), + # "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), + # } - data.control_type = ( - data.control_type.reshape(1, -1) - .to(data.device, dtype=data.prompt_embeds.dtype) - ) + data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) - data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type) # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - # prepare latents for unet using the guider - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) - # prepare latents for controlnet using the guider - data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) + ( + latents, + prompt_embeds, + add_time_ids, + pooled_prompt_embeds, + mask, + masked_image_latents, + ip_adapter_embeds, + ) = pipeline.guider.prepare_inputs( + pipeline.unet, + data.latents, + (data.prompt_embeds, data.negative_prompt_embeds), + (data.add_time_ids, data.negative_add_time_ids), + (data.pooled_prompt_embeds, data.negative_pooled_prompt_embeds), + data.mask, + data.masked_image_latents, + (data.ip_adapter_embeds, data.negative_ip_adapter_embeds), + ) if isinstance(data.controlnet_keep[i], list): data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] @@ -3126,48 +3082,72 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.controlnet_cond_scale = data.controlnet_cond_scale[0] data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(data.control_model_input, t), - t, - encoder_hidden_states=data.controlnet_prompt_embeds, - controlnet_cond=data.control_image, - control_type=data.control_type, - control_type_idx=data.control_mode, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, - return_dict=False, - ) + for batch_index, ( + latents_i, prompt_embeds_i, add_time_ids_i, pooled_prompt_embeds_i, mask_i, masked_image_latents_i, ip_adapter_embeds_i + ) in enumerate(zip( + latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds + )): + latents_i = pipeline.scheduler.scale_model_input(latents_i, t) + + # Prepare for inpainting + if data.num_channels_unet == 9: + latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + + # Prepare additional conditionings + data.added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds_i, + "time_ids": add_time_ids_i, + } + if ip_adapter_embeds_i is not None: + data.added_cond_kwargs["image_embeds"] = ip_adapter_embeds_i + + # Prepare controlnet additional conditionings + data.controlnet_added_cond_kwargs = { + "text_embeds": pooled_prompt_embeds_i, + "time_ids": add_time_ids_i, + } + + if pipeline.guider.is_conditional or not data.guess_mode: + data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( + latents_i, + t, + encoder_hidden_states=prompt_embeds_i, + controlnet_cond=data.control_image, + control_type=data.control_type, + control_type_idx=data.control_mode, + conditioning_scale=data.cond_scale, + guess_mode=data.guess_mode, + added_cond_kwargs=data.controlnet_added_cond_kwargs, + return_dict=False, + ) + elif pipeline.guider.is_unconditional and data.guess_mode: + data.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] + data.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - data.down_block_res_samples = pipeline.guider.prepare_input( - data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] - ) - data.mid_block_res_sample = pipeline.guider.prepare_input( - data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) - ) + if data.num_channels_unet == 9: + latents_i = torch.cat([latents_i, mask_i, masked_image_latents_i], dim=1) + + data.noise_pred = pipeline.unet( + latents_i, + t, + encoder_hidden_states=prompt_embeds_i, + timestep_cond=data.timestep_cond, + cross_attention_kwargs=data.cross_attention_kwargs, + added_cond_kwargs=data.added_cond_kwargs, + down_block_additional_residuals=data.down_block_res_samples, + mid_block_additional_residual=data.mid_block_res_sample, + return_dict=False, + )[0] + data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) + + # Perform guidance + outputs = pipeline.guider.outputs + data.noise_pred = pipeline.guider(**outputs) - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) - # compute the previous noisy sample x_t -> x_t-1 + # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -3180,14 +3160,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.init_latents_proper = pipeline.scheduler.add_noise( data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) ) - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - pipeline.controlnet_guider.reset_guider(pipeline) self.add_block_state(state, data) From 31593e2c3336b5eed36ae4349214dc612585ebf2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 15:56:36 +0200 Subject: [PATCH 04/39] update --- .../pipeline_stable_diffusion_xl_modular.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 24d7e333c4fa..0cb4294e12b9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -184,6 +184,7 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("image_encoder", CLIPVisionModelWithProjection), ComponentSpec("feature_extractor", CLIPImageProcessor), ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("guider", GuiderType), ] @property @@ -276,7 +277,7 @@ def prepare_ip_adapter_image_embeds( def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) - data.do_classifier_free_guidance = data.guidance_scale > 1.0 + data.do_classifier_free_guidance = pipeline.guider.num_conditions > 1 data.device = pipeline._execution_device data.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( @@ -315,7 +316,7 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", CLIPTokenizer), - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), + ComponentSpec("guider", GuiderType), ] @property @@ -3490,6 +3491,11 @@ def description(self): "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ "- for text-to-image generation, all you need to provide is `prompt`" +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + # block mapping TEXT2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), @@ -3611,7 +3617,6 @@ def num_channels_latents(self): "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), - "guidance_scale": InputParam("guidance_scale", type_hint=float, default=5.0, description="Classifier-Free Diffusion Guidance scale"), "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), @@ -3636,7 +3641,6 @@ def num_channels_latents(self): "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), - "guidance_rescale": InputParam("guidance_rescale", type_hint=float, default=0.0, description="Guidance rescale factor to fix overexposure"), "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), "guider_kwargs": InputParam("guider_kwargs", type_hint=Optional[Dict[str, Any]], description="Kwargs dictionary passed to the Guider"), "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), @@ -3704,4 +3708,4 @@ def num_channels_latents(self): SDXL_OUTPUTS_SCHEMA = { "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") -} \ No newline at end of file +} From 625530295de7bb2e14c169a05c4358772fea907f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 15:57:17 +0200 Subject: [PATCH 05/39] update --- .../pipeline_stable_diffusion_xl_modular.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 0cb4294e12b9..a02e869e8b32 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -233,10 +233,11 @@ def encode_image(self, components, image, device, num_images_per_prompt, output_ # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] - negative_image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] @@ -255,18 +256,21 @@ def prepare_ip_adapter_image_embeds( ) image_embeds.append(single_image_embeds[None, :]) - negative_image_embeds.append(single_negative_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) From 2238f55f40b3a7ddffab0fb176aa8d1bf1dec22f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 16:43:16 +0200 Subject: [PATCH 06/39] cfg zero* --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 1 + .../guiders/classifier_free_guidance.py | 5 + .../classifier_free_zero_star_guidance.py | 143 ++++++++++++++++++ src/diffusers/guiders/skip_layer_guidance.py | 6 + 5 files changed, 157 insertions(+) create mode 100644 src/diffusers/guiders/classifier_free_zero_star_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d3a61df3baff..b2e24614b97a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -134,6 +134,7 @@ [ "AdaptiveProjectedGuidance", "ClassifierFreeGuidance", + "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", ] ) @@ -724,6 +725,7 @@ from .guiders import ( AdaptiveProjectedGuidance, ClassifierFreeGuidance, + ClassifierFreeZeroStarGuidance, SkipLayerGuidance, ) from .hooks import ( diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index e3c6494de090..ac3837a6b4d7 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -20,6 +20,7 @@ if is_torch_available(): from .adaptive_projected_guidance import AdaptiveProjectedGuidance from .classifier_free_guidance import ClassifierFreeGuidance + from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance GuiderType = Union[ClassifierFreeGuidance, SkipLayerGuidance] diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 3deacdfb2863..6978080b7152 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -23,20 +23,25 @@ class ClassifierFreeGuidance(BaseGuidance): """ Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper proposes scaling and shifting the conditional distribution based on the difference between conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + The intution behind the original formulation can be thought of as moving the conditional distribution estimates further away from the unconditional distribution estimates, while the diffusers-native implementation can be thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py new file mode 100644 index 000000000000..04c504f8f2d6 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -0,0 +1,143 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union, Tuple, List + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class ClassifierFreeZeroStarGuidance(BaseGuidance): + """ + Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886 + + This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free + guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion + process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the + quality of generated images. + + The authors of the paper suggest setting zero initialization in the first 4% of the inference steps. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + zero_init_steps (`int`, defaults to `1`): + The number of inference steps for which the noise predictions are zeroed out (see Section 4.2). + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + zero_init_steps: int = 1, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.zero_init_steps = zero_init_steps + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if self._step < self.zero_init_steps: + pred = torch.zeros_like(pred_cond) + elif not self._is_cfg_enabled(): + pred = pred_cond + else: + pred_cond_flat = pred_cond.flatten(1) + pred_uncond_flat = pred_uncond.flatten(1) + alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat) + alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1)) + pred_uncond = pred_uncond * alpha + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + cond_dtype = cond.dtype + cond = cond.float() + uncond = uncond.float() + dot_product = torch.sum(cond * uncond, dim=1, keepdim=True) + squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + scale = dot_product / squared_norm + return scale.to(dtype=cond_dtype) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 4abb93272e9a..bac851c0dc28 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -26,20 +26,26 @@ class SkipLayerGuidance(BaseGuidance): """ Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664 + SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional batch of data, apart from the conditional and unconditional batches already used in CFG ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions based on the difference between conditional without skipping and conditional with skipping predictions. + The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse version of the model for the conditional prediction). + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving generation quality in video diffusion models. + Additional reading: - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text From 52b9b61f62087f0120cd6045bbac3f0e4cdc55e2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 17:34:10 +0200 Subject: [PATCH 07/39] use unwrap_module for torch compiled modules --- .../pipeline_stable_diffusion_xl_modular.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index a02e869e8b32..0b36633efc4e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -30,7 +30,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import randn_tensor, unwrap_module from ..controlnet.multicontrolnet import MultiControlNetModel from ..modular_pipeline import ( AutoPipelineBlocks, @@ -2545,7 +2545,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.height = data.height * pipeline.vae_scale_factor data.width = data.width * pipeline.vae_scale_factor - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + controlnet = unwrap_module(pipeline.controlnet) # (1.1) # control_guidance_start/control_guidance_end (align format) @@ -2973,7 +2973,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.height = data.height * pipeline.vae_scale_factor data.width = data.width * pipeline.vae_scale_factor - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet + controlnet = unwrap_module(pipeline.controlnet) # (1.1) # control guidance From d8d2ea37293727d2201ddc09b03a307bbb07c421 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 17:35:30 +0200 Subject: [PATCH 08/39] remove guider kwargs --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 0b36633efc4e..c67ec2f4decb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -3646,7 +3646,6 @@ def num_channels_latents(self): "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), - "guider_kwargs": InputParam("guider_kwargs", type_hint=Optional[Dict[str, Any]], description="Kwargs dictionary passed to the Guider"), "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), "return_dict": InputParam("return_dict", type_hint=bool, default=True, description="Whether to return a StableDiffusionXLPipelineOutput"), "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), From b31904b853d0d271c16c640a712dc3ac43f15a05 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 17:39:06 +0200 Subject: [PATCH 09/39] remove commented code --- .../pipeline_stable_diffusion_xl_modular.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index c67ec2f4decb..68d9d913bd3a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -3041,13 +3041,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: else: pipeline.guider.force_enable() - # (3) Prepare conditional inputs for controlnet using the guider - # data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - # data.controlnet_added_cond_kwargs = { - # "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - # "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - # } - data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) From 57c7e15a919d5d01fc236058ba1dc8fd2d16cccd Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 17:45:56 +0200 Subject: [PATCH 10/39] remove old guider --- src/diffusers/guider.py | 748 ------------------------------ src/diffusers/guiders/__init__.py | 2 +- 2 files changed, 1 insertion(+), 749 deletions(-) delete mode 100644 src/diffusers/guider.py diff --git a/src/diffusers/guider.py b/src/diffusers/guider.py deleted file mode 100644 index b42dca64d651..000000000000 --- a/src/diffusers/guider.py +++ /dev/null @@ -1,748 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from .models.attention_processor import ( - Attention, - AttentionProcessor, - PAGCFGIdentitySelfAttnProcessor2_0, - PAGIdentitySelfAttnProcessor2_0, -) -from .utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - r""" - Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on - Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://arxiv.org/pdf/2305.08891.pdf). - - Args: - noise_cfg (`torch.Tensor`): - The predicted noise tensor for the guided diffusion process. - noise_pred_text (`torch.Tensor`): - The predicted noise tensor for the text-guided diffusion process. - guidance_rescale (`float`, *optional*, defaults to 0.0): - A rescale factor applied to the noise predictions. - - Returns: - noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -class CFGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size :] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 2 - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_classifier_free_guidance: - return model_output - - noise_pred_uncond, noise_pred_text = model_output.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred - - -class PAGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - def __init__( - self, - pag_applied_layers: Union[str, List[str]], - pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( - PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0(), - ), - ): - r""" - Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - - Args: - pag_applied_layers (`str` or `List[str]`): - One or more strings identifying the layer names, or a simple regex for matching multiple layers, where - PAG is to be applied. A few ways of expected usage are as follows: - - Single layers specified as - "blocks.{layer_index}" - - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] - - Multiple layers as a block name - "mid" - - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" - pag_attn_processors: - (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention - processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second - attention processor is for PAG with CFG disabled (unconditional only). - """ - - if not isinstance(pag_applied_layers, list): - pag_applied_layers = [pag_applied_layers] - if pag_attn_processors is not None: - if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: - raise ValueError("Expected a tuple of two attention processors") - - for i in range(len(pag_applied_layers)): - if not isinstance(pag_applied_layers[i], str): - raise ValueError( - f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" - ) - - self.pag_applied_layers = pag_applied_layers - self._pag_attn_processors = pag_attn_processors - - def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance): - r""" - Set the attention processor for the PAG layers. - """ - pag_attn_processors = self._pag_attn_processors - pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] - - def is_self_attn(module: nn.Module) -> bool: - r""" - Check if the module is self-attention module based on its name. - """ - return isinstance(module, Attention) and not module.is_cross_attention - - def is_fake_integral_match(layer_id, name): - layer_id = layer_id.split(".")[-1] - name = name.split(".")[-1] - return layer_id.isnumeric() and name.isnumeric() and layer_id == name - - for layer_id in pag_applied_layers: - # for each PAG layer input, we find corresponding self-attention layers in the unet model - target_modules = [] - - for name, module in model.named_modules(): - # Identify the following simple cases: - # (1) Self Attention layer existing - # (2) Whether the module name matches pag layer id even partially - # (3) Make sure it's not a fake integral match if the layer_id ends with a number - # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" - if ( - is_self_attn(module) - and re.search(layer_id, name) is not None - and not is_fake_integral_match(layer_id, name) - ): - logger.debug(f"Applying PAG to layer: {name}") - target_modules.append(module) - - if len(target_modules) == 0: - raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") - - for module in target_modules: - module.processor = pag_attn_proc - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and not self._disable_guidance - - @property - def do_perturbed_attention_guidance(self): - return self._pag_scale > 0 and not self._disable_guidance - - @property - def do_pag_adaptive_scaling(self): - return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def batch_size(self): - return self._batch_size - - @property - def pag_scale(self): - return self._pag_scale - - @property - def pag_adaptive_scale(self): - return self._pag_adaptive_scale - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - pag_scale = guider_kwargs.get("pag_scale", 3.0) - pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0) - - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is a required argument for PAGGuider") - - guidance_scale = guider_kwargs.get("guidance_scale", None) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - disable_guidance = guider_kwargs.get("disable_guidance", False) - - if guidance_scale is None: - raise ValueError("guidance_scale is a required argument for PAGGuider") - - self._pag_scale = pag_scale - self._pag_adaptive_scale = pag_adaptive_scale - self._guidance_scale = guidance_scale - self._disable_guidance = disable_guidance - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None: - pipeline.original_attn_proc = pipeline.unet.attn_processors - self._set_pag_attn_processor( - model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, - pag_applied_layers=self.pag_applied_layers, - do_classifier_free_guidance=self.do_classifier_free_guidance, - ) - - def reset_guider(self, pipeline): - if ( - self.do_perturbed_attention_guidance - and hasattr(pipeline, "original_attn_proc") - and pipeline.original_attn_proc is not None - ): - pipeline.unet.set_attn_processor(pipeline.original_attn_proc) - pipeline.original_attn_proc = None - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Perturbed Attention Guidance (PAG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 3 - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 3: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size : self.batch_size * 2] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The negative conditional input. It can be a single tensor or a list of tensors. It must have the same - length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - - if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_perturbed_attention_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - - cond = torch.cat([cond] * 2, dim=0) - if self.do_classifier_free_guidance: - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - else: - prepared_input.append(cond) - - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_perturbed_attention_guidance: - return cond_input - - cond_input = torch.cat([cond_input] * 2, dim=0) - if self.do_classifier_free_guidance: - return torch.cat([negative_cond_input, cond_input], dim=0) - else: - return cond_input - - else: - raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_perturbed_attention_guidance: - return model_output - - if self.do_pag_adaptive_scaling: - pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0) - else: - pag_scale = self._pag_scale - - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3) - noise_pred = ( - noise_pred_uncond - + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - + pag_scale * (noise_pred_text - noise_pred_perturb) - ) - else: - noise_pred_text, noise_pred_perturb = model_output.chunk(2) - noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - - return noise_pred - - -class MomentumBuffer: - def __init__(self, momentum: float): - self.momentum = momentum - self.running_average = 0 - - def update(self, update_value: torch.Tensor): - new_average = self.momentum * self.running_average - self.running_average = update_value + new_average - - -class APGGuider: - """ - This class is used to guide the pipeline with APG (Adaptive Projected Guidance). - """ - - def normalized_guidance( - self, - pred_cond: torch.Tensor, - pred_uncond: torch.Tensor, - guidance_scale: float, - momentum_buffer: MomentumBuffer = None, - norm_threshold: float = 0.0, - eta: float = 1.0, - ): - """ - Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion - Models](https://arxiv.org/pdf/2410.02416) - """ - diff = pred_cond - pred_uncond - if momentum_buffer is not None: - momentum_buffer.update(diff) - diff = momentum_buffer.running_average - if norm_threshold > 0: - ones = torch.ones_like(diff) - diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True) - scale_factor = torch.minimum(ones, norm_threshold / diff_norm) - diff = diff * scale_factor - v0, v1 = diff.double(), pred_cond.double() - v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) - v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 - v0_orthogonal = v0 - v0_parallel - diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype) - normalized_update = diff_orthogonal + eta * diff_parallel - pred_guided = pred_cond + (guidance_scale - 1) * normalized_update - return pred_guided - - @property - def adaptive_projected_guidance_momentum(self): - return self._adaptive_projected_guidance_momentum - - @property - def adaptive_projected_guidance_rescale_factor(self): - return self._adaptive_projected_guidance_rescale_factor - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None) - adaptive_projected_guidance_rescale_factor = guider_kwargs.get( - "adaptive_projected_guidance_rescale_factor", 15.0 - ) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum - self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - if adaptive_projected_guidance_momentum is not None: - self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum) - else: - self.momentum_buffer = None - self.scheduler = pipeline.scheduler - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size :] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 2 - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_classifier_free_guidance: - return model_output - - if latents is None: - raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).") - - sigma = self.scheduler.sigmas[self.scheduler.step_index] - noise_pred = latents - sigma * model_output - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = self.normalized_guidance( - noise_pred_text, - noise_pred_uncond, - self.guidance_scale, - self.momentum_buffer, - self.adaptive_projected_guidance_rescale_factor, - ) - noise_pred = (latents - noise_pred) / sigma - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred - - -Guiders = Union[CFGGuider, PAGGuider, APGGuider] \ No newline at end of file diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index ac3837a6b4d7..52b9176f5fb7 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -23,4 +23,4 @@ from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance - GuiderType = Union[ClassifierFreeGuidance, SkipLayerGuidance] + GuiderType = Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance] From 8d31c699a5bed011afc2ac6569eff3892fe885ae Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 08:26:03 +0200 Subject: [PATCH 11/39] fix slg bug --- .../guiders/adaptive_projected_guidance.py | 6 ++-- src/diffusers/guiders/guider_utils.py | 7 ++++ src/diffusers/guiders/skip_layer_guidance.py | 35 ++++++++++++++----- src/diffusers/hooks/layer_skip.py | 13 ++++--- .../pipeline_stable_diffusion_xl_modular.py | 3 ++ 5 files changed, 48 insertions(+), 16 deletions(-) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index ab5175745b07..45bd196860f4 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -86,7 +86,7 @@ def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None - if not self._is_cfg_enabled(): + if not self._is_apg_enabled(): pred = pred_cond else: pred = normalized_guidance( @@ -111,11 +111,11 @@ def is_conditional(self) -> bool: @property def num_conditions(self) -> int: num_conditions = 1 - if self._is_cfg_enabled(): + if self._is_apg_enabled(): num_conditions += 1 return num_conditions - def _is_cfg_enabled(self) -> bool: + def _is_apg_enabled(self) -> bool: if not self._enabled: return False diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 690afae89178..60859bf390f6 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -68,6 +68,13 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen self._preds = {} self._num_outputs_prepared = 0 + def prepare_models(self, denoiser: torch.nn.Module) -> None: + """ + Prepares the models for the guidance technique on a given batch of data. This method should be overridden in + subclasses to implement specific model preparation logic. + """ + pass + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index bac851c0dc28..3fbfd771eff9 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -54,6 +54,10 @@ class SkipLayerGuidance(BaseGuidance): skip_layer_guidance_scale (`float`, defaults to `2.8`): The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher values, but it may also lead to overexposure and saturation. + skip_layer_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which skip layer guidance starts. + skip_layer_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which skip layer guidance stops. skip_layer_guidance_layers (`int` or `List[int]`, *optional*): The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion @@ -81,20 +85,33 @@ def __init__( self, guidance_scale: float = 7.5, skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, guidance_rescale: float = 0.0, use_original_formulation: bool = False, - start: float = 0.01, - stop: float = 0.2, + start: float = 0.0, + stop: float = 1.0, ): super().__init__(start, stop) self.guidance_scale = guidance_scale self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.skip_layer_guidance_start = skip_layer_guidance_start + self.skip_layer_guidance_stop = skip_layer_guidance_stop self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation + if not (0.0 <= skip_layer_guidance_start < 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}." + ) + if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}." + ) + if skip_layer_guidance_layers is None and skip_layer_config is None: raise ValueError( "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." @@ -122,11 +139,12 @@ def __init__( self.skip_layer_config = skip_layer_config self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - if self._num_outputs_prepared == 0 and self._is_slg_enabled(): + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_slg_enabled() and self.is_conditional and self._num_outputs_prepared > 0: for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): _apply_layer_skip_hook(denoiser, config, name=name) - + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: num_conditions = self.num_conditions list_of_inputs = [] for arg in args: @@ -161,7 +179,8 @@ def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None key = "pred_cond_skip" self._preds[key] = pred - if self._num_outputs_prepared == self.num_conditions: + if key == "pred_cond_skip": + # If we are in SLG mode, we need to remove the hooks after inference registry = HookRegistry.check_if_exists_or_initialize(denoiser) # Remove the hooks after inference for hook_name in self._skip_layer_hook_names: @@ -233,8 +252,8 @@ def _is_slg_enabled(self) -> bool: is_within_range = True if self._num_inference_steps is not None: - skip_start_step = int(self._start * self._num_inference_steps) - skip_stop_step = int(self._stop * self._num_inference_steps) + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) is_within_range = skip_start_step < self._step < skip_stop_step is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index b8471523902a..e28322ac4865 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -80,14 +80,17 @@ def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bo def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.skip_attention_scores: + print("Skipping attention scores") with AttentionScoreSkipFunctionMode(): return self.fn_ref.original_forward(*args, **kwargs) else: + print("Skipping attention processor output") return self.skip_processor_output_fn(module, *args, **kwargs) class FeedForwardSkipHook(ModelHook): def new_forward(self, module: torch.nn.Module, *args, **kwargs): + print("Skipping feed-forward block") output = kwargs.get("hidden_states", None) if output is None: output = kwargs.get("x", None) @@ -102,18 +105,22 @@ def initialize_hook(self, module): return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): + print("Skipping transformer block") return self._metadata.skip_block_output_fn(module, *args, **kwargs) def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: r""" Apply layer skipping to internal layers of a transformer. + Args: module (`torch.nn.Module`): The transformer model to which the layer skip hook should be applied. config (`LayerSkipConfig`): The configuration for the layer skip hook. + Example: + ```python >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) @@ -168,17 +175,13 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam registry = HookRegistry.check_if_exists_or_initialize(submodule) hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores) registry.register_hook(hook, name) - elif config.skip_ff: + if config.skip_ff: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _FEEDFORWARD_CLASSES): logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") registry = HookRegistry.check_if_exists_or_initialize(submodule) hook = FeedForwardSkipHook() registry.register_hook(hook, name) - else: - raise ValueError( - "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True." - ) if not blocks_found: raise ValueError( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 68d9d913bd3a..aed212c3f880 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2267,6 +2267,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) in enumerate(zip( latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): + pipeline.guider.prepare_models(pipeline.unet) latents_i = pipeline.scheduler.scale_model_input(latents_i, t) # Prepare for inpainting @@ -2670,6 +2671,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) in enumerate(zip( latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): + pipeline.guider.prepare_models(pipeline.unet) latents_i = pipeline.scheduler.scale_model_input(latents_i, t) # Prepare for inpainting @@ -3085,6 +3087,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) in enumerate(zip( latents, prompt_embeds, add_time_ids, pooled_prompt_embeds, mask, masked_image_latents, ip_adapter_embeds )): + pipeline.guider.prepare_models(pipeline.unet) latents_i = pipeline.scheduler.scale_model_input(latents_i, t) # Prepare for inpainting From 1c1d1d52e0712028253d5196545a79cc41ef005d Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 08:26:43 +0200 Subject: [PATCH 12/39] remove debug print --- src/diffusers/hooks/layer_skip.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index e28322ac4865..2a906c315c16 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -80,17 +80,14 @@ def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bo def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.skip_attention_scores: - print("Skipping attention scores") with AttentionScoreSkipFunctionMode(): return self.fn_ref.original_forward(*args, **kwargs) else: - print("Skipping attention processor output") return self.skip_processor_output_fn(module, *args, **kwargs) class FeedForwardSkipHook(ModelHook): def new_forward(self, module: torch.nn.Module, *args, **kwargs): - print("Skipping feed-forward block") output = kwargs.get("hidden_states", None) if output is None: output = kwargs.get("x", None) @@ -105,7 +102,6 @@ def initialize_hook(self, module): return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): - print("Skipping transformer block") return self._metadata.skip_block_output_fn(module, *args, **kwargs) From ba579f4da90b5d6836d80f8b9e61a6df35cc3bad Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 09:21:25 +0200 Subject: [PATCH 13/39] autoguidance --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 1 + src/diffusers/guiders/auto_guidance.py | 172 +++++++++++++++++++++++++ src/diffusers/hooks/layer_skip.py | 72 ++++++++--- 4 files changed, 232 insertions(+), 15 deletions(-) create mode 100644 src/diffusers/guiders/auto_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b2e24614b97a..0672e1035628 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -133,6 +133,7 @@ _import_structure["guiders"].extend( [ "AdaptiveProjectedGuidance", + "AutoGuidance", "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", @@ -724,6 +725,7 @@ else: from .guiders import ( AdaptiveProjectedGuidance, + AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 52b9176f5fb7..c23e48578ebc 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -19,6 +19,7 @@ if is_torch_available(): from .adaptive_projected_guidance import AdaptiveProjectedGuidance + from .auto_guidance import AutoGuidance from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py new file mode 100644 index 000000000000..8c759f497307 --- /dev/null +++ b/src/diffusers/guiders/auto_guidance.py @@ -0,0 +1,172 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class AutoGuidance(BaseGuidance): + """ + AutoGuidance: https://huggingface.co/papers/2406.02507 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + auto_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. + auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. + dropout (`float`, *optional*): + The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or + `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + auto_guidance_layers: Optional[Union[int, List[int]]] = None, + auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + dropout: Optional[float] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.auto_guidance_layers = auto_guidance_layers + self.auto_guidance_config = auto_guidance_config + self.dropout = dropout + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if auto_guidance_layers is None and auto_guidance_config is None: + raise ValueError( + "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance." + ) + if auto_guidance_layers is not None and auto_guidance_config is not None: + raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.") + if (dropout is None and auto_guidance_layers is not None) or (dropout is not None and auto_guidance_layers is None): + raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.") + + if auto_guidance_layers is not None: + if isinstance(auto_guidance_layers, int): + auto_guidance_layers = [auto_guidance_layers] + if not isinstance(auto_guidance_layers, list): + raise ValueError( + f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}." + ) + auto_guidance_config = [LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers] + + if isinstance(auto_guidance_config, LayerSkipConfig): + auto_guidance_config = [auto_guidance_config] + + if not isinstance(auto_guidance_config, list): + raise ValueError( + f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}." + ) + + self.auto_guidance_config = auto_guidance_config + self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_ag_enabled() and self.is_unconditional: + for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + if key == "pred_uncond": + # If we are in AutoGuidance unconditional inference mode, we need to remove the hooks after inference + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._auto_guidance_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_ag_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_ag_enabled(): + num_conditions += 1 + return num_conditions + + def _is_ag_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 2a906c315c16..14b1cf492d0e 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from dataclasses import dataclass from typing import Callable, List, Optional @@ -47,8 +48,12 @@ class LayerSkipConfig: skip_ff (`bool`, defaults to `True`): Whether to skip feed-forward blocks. skip_attention_scores (`bool`, defaults to `False`): - Whether to skip attention score computation in the attention blocks. This is equivalent to using - `value` projections as the output of scaled dot product attention. + Whether to skip attention score computation in the attention blocks. This is equivalent to using `value` + projections as the output of scaled dot product attention. + dropout (`float`, defaults to `1.0`): + The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`, + meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the + skipped layers are fully retained, which is equivalent to not skipping any layers. """ indices: List[int] @@ -56,6 +61,11 @@ class LayerSkipConfig: skip_attention: bool = True skip_attention_scores: bool = False skip_ff: bool = True + dropout: float = 1.0 + + def __post_init__(self): + if not (0 <= self.dropout <= 1): + raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.") class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): @@ -74,36 +84,62 @@ def __torch_function__(self, func, types, args=(), kwargs=None): class AttentionProcessorSkipHook(ModelHook): - def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False): + def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0): self.skip_processor_output_fn = skip_processor_output_fn self.skip_attention_scores = skip_attention_scores + self.dropout = dropout def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.skip_attention_scores: + if math.isclose(self.dropout, 1.0): + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) with AttentionScoreSkipFunctionMode(): - return self.fn_ref.original_forward(*args, **kwargs) + output = self.fn_ref.original_forward(*args, **kwargs) else: - return self.skip_processor_output_fn(module, *args, **kwargs) + if math.isclose(self.dropout, 1.0): + output = self.skip_processor_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output class FeedForwardSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + def new_forward(self, module: torch.nn.Module, *args, **kwargs): - output = kwargs.get("hidden_states", None) - if output is None: - output = kwargs.get("x", None) - if output is None and len(args) > 0: - output = args[0] + if math.isclose(self.dropout, 1.0): + output = kwargs.get("hidden_states", None) + if output is None: + output = kwargs.get("x", None) + if output is None and len(args) > 0: + output = args[0] + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) return output class TransformerBlockSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + def initialize_hook(self, module): self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): - return self._metadata.skip_block_output_fn(module, *args, **kwargs) - + if math.isclose(self.dropout, 1.0): + output = self._metadata.skip_block_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: r""" @@ -132,6 +168,8 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam if config.skip_attention and config.skip_attention_scores: raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.") + if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores: + raise ValueError("Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0.") if config.fqn == "auto": for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: @@ -157,26 +195,30 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam for i, block in enumerate(transformer_blocks): if i not in config.indices: continue + blocks_found = True + if config.skip_attention and config.skip_ff: logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") registry = HookRegistry.check_if_exists_or_initialize(block) - hook = TransformerBlockSkipHook() + hook = TransformerBlockSkipHook(config.dropout) registry.register_hook(hook, name) + elif config.skip_attention or config.skip_attention_scores: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'") output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn registry = HookRegistry.check_if_exists_or_initialize(submodule) - hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores) + hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout) registry.register_hook(hook, name) + if config.skip_ff: for submodule_name, submodule in block.named_modules(): if isinstance(submodule, _FEEDFORWARD_CLASSES): logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") registry = HookRegistry.check_if_exists_or_initialize(submodule) - hook = FeedForwardSkipHook() + hook = FeedForwardSkipHook(config.dropout) registry.register_hook(hook, name) if not blocks_found: From 720783e508c9f76ed4ecd05763db96a989a0ec86 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 21:13:13 +0200 Subject: [PATCH 14/39] smoothed energy guidance --- src/diffusers/__init__.py | 4 + src/diffusers/guiders/__init__.py | 1 + src/diffusers/guiders/guider_utils.py | 4 +- src/diffusers/guiders/skip_layer_guidance.py | 5 +- .../guiders/smoothed_energy_guidance.py | 250 ++++++++++++++++++ src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/_common.py | 11 + src/diffusers/hooks/layer_skip.py | 16 +- .../hooks/smoothed_energy_guidance_utils.py | 148 +++++++++++ .../pipeline_stable_diffusion_xl_modular.py | 12 +- 10 files changed, 431 insertions(+), 21 deletions(-) create mode 100644 src/diffusers/guiders/smoothed_energy_guidance.py create mode 100644 src/diffusers/hooks/smoothed_energy_guidance_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0672e1035628..67d8ae9f79ac 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -137,6 +137,7 @@ "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", + "SmoothedEnergyGuidance", ] ) _import_structure["hooks"].extend( @@ -145,6 +146,7 @@ "HookRegistry", "PyramidAttentionBroadcastConfig", "LayerSkipConfig", + "SmoothedEnergyGuidanceConfig", "apply_faster_cache", "apply_layer_skip", "apply_pyramid_attention_broadcast", @@ -729,12 +731,14 @@ ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, + SmoothedEnergyGuidance, ) from .hooks import ( FasterCacheConfig, HookRegistry, LayerSkipConfig, PyramidAttentionBroadcastConfig, + SmoothedEnergyGuidanceConfig, apply_layer_skip, apply_faster_cache, apply_pyramid_attention_broadcast, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index c23e48578ebc..7b88d61c67b3 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -23,5 +23,6 @@ from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance + from .smoothed_energy_guidance import SmoothedEnergyGuidance GuiderType = Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance] diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 60859bf390f6..e7d22a50d0da 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -55,10 +55,10 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): "`_input_predictions` must be a list of required prediction names for the guidance technique." ) - def force_disable(self): + def _force_disable(self): self._enabled = False - def force_enable(self): + def _force_enable(self): self._enabled = True def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 3fbfd771eff9..64b2b8a73c1a 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -24,8 +24,9 @@ class SkipLayerGuidance(BaseGuidance): """ - Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG): - https://huggingface.co/papers/2411.18664 + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 + + Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664 SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py new file mode 100644 index 000000000000..bd2a61b894e7 --- /dev/null +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -0,0 +1,250 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch + +from ..hooks import HookRegistry +from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + + +class SmoothedEnergyGuidance(BaseGuidance): + """ + Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + seg_guidance_scale (`float`, defaults to `3.0`): + The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + seg_blur_sigma (`float`, defaults to `9999999.0`): + The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in + infinite blur, which means uniform queries. Controlling it exponentially is empirically effective. + seg_blur_threshold_inf (`float`, defaults to `9999.0`): + The threshold above which the blur is considered infinite. + seg_guidance_start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance starts. + seg_guidance_stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance stops. + seg_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If not + provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*): + The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list of + `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + + def __init__( + self, + guidance_scale: float = 7.5, + seg_guidance_scale: float = 2.8, + seg_blur_sigma: float = 9999999.0, + seg_blur_threshold_inf: float = 9999.0, + seg_guidance_start: float = 0.0, + seg_guidance_stop: float = 1.0, + seg_guidance_layers: Optional[Union[int, List[int]]] = None, + seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.seg_guidance_scale = seg_guidance_scale + self.seg_blur_sigma = seg_blur_sigma + self.seg_blur_threshold_inf = seg_blur_threshold_inf + self.seg_guidance_start = seg_guidance_start + self.seg_guidance_stop = seg_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= seg_guidance_start < 1.0): + raise ValueError( + f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}." + ) + if not (seg_guidance_start <= seg_guidance_stop <= 1.0): + raise ValueError( + f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}." + ) + + if seg_guidance_layers is None and seg_guidance_config is None: + raise ValueError( + "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance." + ) + if seg_guidance_layers is not None and seg_guidance_config is not None: + raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.") + + if seg_guidance_layers is not None: + if isinstance(seg_guidance_layers, int): + seg_guidance_layers = [seg_guidance_layers] + if not isinstance(seg_guidance_layers, list): + raise ValueError( + f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}." + ) + seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers] + + if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig): + seg_guidance_config = [seg_guidance_config] + + if not isinstance(seg_guidance_config, list): + raise ValueError( + f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}." + ) + + self.seg_guidance_config = seg_guidance_config + self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_seg_enabled() and self.is_conditional and self._num_outputs_prepared > 0: + for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): + _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if arg is None or isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + list_of_inputs.append([arg[0], arg[1], arg[0]]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + if not self._is_cfg_enabled() and self._is_seg_enabled(): + # If we're predicting pred_cond and pred_cond_seg only, we need to set the key to pred_cond_seg + # to avoid writing into pred_uncond which is not used + if self._num_outputs_prepared == 2: + key = "pred_cond_seg" + self._preds[key] = pred + + if key == "pred_cond_seg": + # If we are in SLG mode, we need to remove the hooks after inference + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._seg_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_seg: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_seg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_cond_seg + pred = pred + self.seg_guidance_scale * shift + elif not self._is_seg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_seg = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_seg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_seg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.seg_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.seg_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 142ff860371c..9d0e96e9e79e 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -8,3 +8,4 @@ from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py index 6ea83dcbf6a7..3d9c99e8189f 100644 --- a/src/diffusers/hooks/_common.py +++ b/src/diffusers/hooks/_common.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +import torch + from ..models.attention import FeedForward, LuminaFeedForward from ..models.attention_processor import Attention, MochiAttention @@ -30,3 +34,10 @@ *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, } ) + + +def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: + for submodule_name, submodule in module.named_modules(): + if submodule_name == fqn: + return submodule + return None diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 14b1cf492d0e..45f42a1f0f86 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -20,7 +20,7 @@ from ..utils import get_logger from ..utils.torch_utils import unwrap_module -from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry from .hooks import HookRegistry, ModelHook @@ -66,12 +66,13 @@ class LayerSkipConfig: def __post_init__(self): if not (0 <= self.dropout <= 1): raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.") + if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores: + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): - def __init__(self) -> None: - super().__init__() - def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} @@ -226,10 +227,3 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam f"Could not find any transformer blocks matching the provided indices {config.indices} and " f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." ) - - -def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: - for submodule_name, submodule in module.named_modules(): - if submodule_name == fqn: - return submodule - return None diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py new file mode 100644 index 000000000000..20df0de048c7 --- /dev/null +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from ..utils import get_logger +from ._common import _ATTENTION_CLASSES, _get_submodule_from_fqn +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook" + + +@dataclass +class SmoothedEnergyGuidanceConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. + "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + _query_proj_identifiers (`List[str]`, defaults to `None`): + The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. + If `None`, `to_q` is used by default. + """ + + indices: List[int] + fqn: str = "auto" + _query_proj_identifiers: List[str] = None + + +class SmoothedEnergyGuidanceHook(ModelHook): + def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None: + super().__init__() + self.blur_sigma = blur_sigma + self.blur_threshold_inf = blur_threshold_inf + + def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor: + # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102 + kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2 + smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf) + return smoothed_output + + +def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None: + name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + if config._query_proj_identifiers is None: + config._query_proj_identifiers = ["to_q"] + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + for submodule_name, submodule in block.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention: + continue + for identifier in config._query_proj_identifiers: + query_proj = getattr(submodule, identifier, None) + if query_proj is None or not isinstance(query_proj, torch.nn.Linear): + continue + logger.debug( + f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}" + ) + registry = HookRegistry.check_if_exists_or_initialize(query_proj) + hook = SmoothedEnergyGuidanceHook(blur_sigma) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) + + +# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 +def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: + assert query.ndim == 3 + + is_inf = sigma > sigma_threshold_inf + batch_size, seq_len, embed_dim = query.shape + + seq_len_sqrt = int(math.sqrt(seq_len)) + num_square_tokens = seq_len_sqrt * seq_len_sqrt + query_slice = query[:, :num_square_tokens, :] + query_slice = query_slice.permute(0, 2, 1) + query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt) + + if is_inf: + kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1)) + kernel_size_half = (kernel_size - 1) / 2 + + x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + kernel1d = kernel1d.to(query) + kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :]) + kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + query_slice = F.pad(query_slice, padding, mode="reflect") + query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim) + else: + query[:] = query.mean(dim=(-2, -1), keepdim=True) + + query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) + query_slice = query_slice.permute(0, 2, 1) + query[:, :num_square_tokens, :] = query_slice + + return query diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index aed212c3f880..76030a71535a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2231,9 +2231,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider.force_disable() + pipeline.guider._force_disable() else: - pipeline.guider.force_enable() + pipeline.guider._force_enable() # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -2626,9 +2626,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (2) Prepare conditional inputs for unet using the guider data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider.force_disable() + pipeline.guider._force_disable() else: - pipeline.guider.force_enable() + pipeline.guider._force_enable() # (4) Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) @@ -3039,9 +3039,9 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False if data.disable_guidance: - pipeline.guider.force_disable() + pipeline.guider._force_disable() else: - pipeline.guider.force_enable() + pipeline.guider._force_enable() data.control_type = data.control_type.reshape(1, -1).to(data.device, dtype=data.prompt_embeds.dtype) repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] From b9bcd469f13287afb9deaf1486714170d26b9999 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Apr 2025 21:25:15 +0200 Subject: [PATCH 15/39] add note about seg --- src/diffusers/guiders/smoothed_energy_guidance.py | 3 +++ .../hooks/smoothed_energy_guidance_utils.py | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index bd2a61b894e7..2328aa82ec98 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -25,6 +25,9 @@ class SmoothedEnergyGuidance(BaseGuidance): """ Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. Args: guidance_scale (`float`, defaults to `7.5`): diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py index 20df0de048c7..f0366e29887f 100644 --- a/src/diffusers/hooks/smoothed_energy_guidance_utils.py +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -113,6 +113,16 @@ def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: Smooth # Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: + """ + This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian + blur. However, some models use joint text-visual token attention for which this may not be suitable. Additionally, + this implementation also assumes that the visual tokens come from a square image/video. In practice, despite + these assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results + for Smoothed Energy Guidance. + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. + """ assert query.ndim == 3 is_inf = sigma > sigma_threshold_inf @@ -139,10 +149,10 @@ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma query_slice = F.pad(query_slice, padding, mode="reflect") query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim) else: - query[:] = query.mean(dim=(-2, -1), keepdim=True) + query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True) query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) query_slice = query_slice.permute(0, 2, 1) - query[:, :num_square_tokens, :] = query_slice + query[:, :num_square_tokens, :] = query_slice.clone() return query From 2dc673a213dba107aa7463a73295421ff1d30218 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 06:38:03 +0200 Subject: [PATCH 16/39] tangential cfg --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 3 +- .../guiders/adaptive_projected_guidance.py | 7 +- .../guiders/smoothed_energy_guidance.py | 5 +- .../tangential_classifier_free_guidance.py | 133 ++++++++++++++++++ src/diffusers/hooks/layer_skip.py | 2 +- 6 files changed, 148 insertions(+), 4 deletions(-) create mode 100644 src/diffusers/guiders/tangential_classifier_free_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 67d8ae9f79ac..a4f55acf8b70 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -138,6 +138,7 @@ "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", "SmoothedEnergyGuidance", + "TangentialClassifierFreeGuidance", ] ) _import_structure["hooks"].extend( @@ -732,6 +733,7 @@ ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, + TangentialClassifierFreeGuidance, ) from .hooks import ( FasterCacheConfig, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 7b88d61c67b3..3c1ee293382d 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -24,5 +24,6 @@ from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance from .smoothed_energy_guidance import SmoothedEnergyGuidance + from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance - GuiderType = Union[AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance] + GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance] diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 45bd196860f4..05c186e58d9f 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -155,20 +155,25 @@ def normalized_guidance( ): diff = pred_cond - pred_uncond dim = [-i for i in range(1, len(diff.shape))] + if momentum_buffer is not None: momentum_buffer.update(diff) diff = momentum_buffer.running_average + if norm_threshold > 0: ones = torch.ones_like(diff) diff_norm = diff.norm(p=2, dim=dim, keepdim=True) scale_factor = torch.minimum(ones, norm_threshold / diff_norm) diff = diff * scale_factor + v0, v1 = diff.double(), pred_cond.double() v1 = torch.nn.functional.normalize(v1, dim=dim) v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 v0_orthogonal = v0 - v0_parallel diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) normalized_update = diff_orthogonal + eta * diff_parallel + pred = pred_cond if use_original_formulation else pred_uncond - pred = pred + (guidance_scale - 1) * normalized_update + pred = pred + guidance_scale * normalized_update + return pred diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 2328aa82ec98..906900856f4c 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -27,7 +27,10 @@ class SmoothedEnergyGuidance(BaseGuidance): Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 SEG is only supported as an experimental prototype feature for now, so the implementation may be modified - in the future without warning or guarantee of reproducibility. + in the future without warning or guarantee of reproducibility. This implementation assumes: + - Generated images are square (height == width) + - The model does not combine different modalities together (e.g., text and image latent streams are + not combined together such as Flux) Args: guidance_scale (`float`, defaults to `7.5`): diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py new file mode 100644 index 000000000000..078d795baa68 --- /dev/null +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -0,0 +1,133 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union, Tuple, List + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class TangentialClassifierFreeGuidance(BaseGuidance): + """ + Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_tcfg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_tcfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_tcfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor: + cond_dtype = pred_cond.dtype + preds = torch.stack([pred_cond, pred_uncond], dim=1).float() + preds = preds.flatten(2) + U, S, Vh = torch.linalg.svd(preds, full_matrices=False) + Vh_modified = Vh.clone() + Vh_modified[:, 1] = 0 + + uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float() + x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1)) + x_Vh_V = torch.matmul(x_Vh, Vh_modified) + pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype) + + pred = pred_cond if use_original_formulation else pred_uncond + shift = pred_cond - pred_uncond + pred = pred + guidance_scale * shift + + return pred diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 45f42a1f0f86..c50d2b7471e4 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -92,7 +92,7 @@ def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bo def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.skip_attention_scores: - if math.isclose(self.dropout, 1.0): + if not math.isclose(self.dropout, 1.0): raise ValueError( "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." ) From 77d8a285bf36960a2e0315725e322e2f0f1f6197 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 08:08:00 +0200 Subject: [PATCH 17/39] cfg plus plus --- src/diffusers/__init__.py | 2 + src/diffusers/guiders/__init__.py | 1 + .../classifier_free_guidance_plus_plus.py | 117 ++++++++++++++++++ src/diffusers/guiders/guider_utils.py | 9 +- .../tangential_classifier_free_guidance.py | 1 - .../pipeline_stable_diffusion_xl_modular.py | 9 +- .../schedulers/scheduling_euler_discrete.py | 29 +++++ 7 files changed, 163 insertions(+), 5 deletions(-) create mode 100644 src/diffusers/guiders/classifier_free_guidance_plus_plus.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a4f55acf8b70..424011961ab0 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -134,6 +134,7 @@ [ "AdaptiveProjectedGuidance", "AutoGuidance", + "CFGPlusPlusGuidance", "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance", @@ -729,6 +730,7 @@ from .guiders import ( AdaptiveProjectedGuidance, AutoGuidance, + CFGPlusPlusGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 3c1ee293382d..56e95c92b697 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -20,6 +20,7 @@ if is_torch_available(): from .adaptive_projected_guidance import AdaptiveProjectedGuidance from .auto_guidance import AutoGuidance + from .classifier_free_guidance_plus_plus import CFGPlusPlusGuidance from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py new file mode 100644 index 000000000000..516dbfa0e05f --- /dev/null +++ b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py @@ -0,0 +1,117 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union, Tuple, List + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs + + +class CFGPlusPlusGuidance(BaseGuidance): + """ + CFG++: https://huggingface.co/papers/2406.08070 + + Args: + guidance_scale (`float`, defaults to `0.7`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 0.7, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + return _default_prepare_inputs(denoiser, self.num_conditions, *args) + + def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_cfgpp_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor: + if self._is_cfgpp_enabled(): + # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later! + pred_cond = self._preds["pred_cond"] + pred_uncond = self._preds["pred_uncond"] + diff = pred_uncond - pred_cond + pred = pred + diff * self.guidance_scale * self._sigma_next + return pred + + @property + def is_conditional(self) -> bool: + return self._num_outputs_prepared == 0 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfgpp_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfgpp_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + return is_within_range diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index e7d22a50d0da..f51452ed0cee 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -37,6 +37,8 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): self._step: int = None self._num_inference_steps: int = None self._timestep: torch.LongTensor = None + self._sigma: torch.Tensor = None + self._sigma_next: torch.Tensor = None self._preds: Dict[str, torch.Tensor] = {} self._num_outputs_prepared: int = 0 self._enabled = True @@ -61,10 +63,12 @@ def _force_disable(self): def _force_enable(self): self._enabled = True - def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor, sigma: torch.Tensor, sigma_next: torch.Tensor) -> None: self._step = step self._num_inference_steps = num_inference_steps self._timestep = timestep + self._sigma = sigma + self._sigma_next = sigma_next self._preds = {} self._num_outputs_prepared = 0 @@ -91,6 +95,9 @@ def __call__(self, **kwargs) -> Any: def forward(self, *args, **kwargs) -> Any: raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") + def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor: + return pred + @property def is_conditional(self) -> bool: raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 078d795baa68..7529114bfd6f 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -58,7 +58,6 @@ def __init__( self.guidance_scale = guidance_scale self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - self.momentum_buffer = None def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: return _default_prepare_inputs(denoiser, self.num_conditions, *args) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 76030a71535a..8e0ea4545f29 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2241,7 +2241,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) ( latents, @@ -2301,6 +2301,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + data.latents = pipeline.guider.post_scheduler_step(data.latents) if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): @@ -2637,7 +2638,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (5) Denoise loop with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) ( latents, @@ -2730,6 +2731,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + data.latents = pipeline.guider.post_scheduler_step(data.latents) if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): @@ -3053,7 +3055,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) ( latents, @@ -3148,6 +3150,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] + data.latents = pipeline.guider.post_scheduler_step(data.latents) if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 56757f3ca197..4adec768b776 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -669,6 +669,35 @@ def step( prev_sample = sample + derivative * dt + # denoised = sample - model_output * sigmas[i] + # d = (sample - denoised) / sigmas[i] + # new_sample = denoised + d * sigmas[i + 1] + + # new_sample = denoised + (sample - denoised) * sigmas[i + 1] / sigmas[i] + # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] + # new_sample = sample + model_output * (sigmas[i + 1] - sigmas[i]) + # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] --- (1) + + # CFG++ ===== + # denoised = sample - model_output * sigmas[i] + # uncond_denoised = sample - model_output_uncond * sigmas[i] + # d = (sample - uncond_denoised) / sigmas[i] + # new_sample = denoised + d * sigmas[i + 1] + + # new_sample = denoised + (sample - uncond_denoised) * sigmas[i + 1] / sigmas[i] + # new_sample = sample - model_output * sigmas[i] + model_output_uncond * sigmas[i + 1] --- (2) + + # To go from (1) to (2): + # new_sample_2 = new_sample_1 - model_output * sigmas[i + 1] + model_output_uncond * sigmas[i + 1] + # new_sample_2 = new_sample_1 + (model_output_uncond - model_output) * sigmas[i + 1] + # new_sample_2 = new_sample_1 + diff * sigmas[i + 1] + + # diff = model_output_uncond - model_output + # diff = model_output_uncond - (model_output_uncond + g * (model_output_cond - model_output_uncond)) + # diff = model_output_uncond - (g * model_output_cond + (1 - g) * model_output_uncond) + # diff = model_output_uncond - g * model_output_cond + (g - 1) * model_output_uncond + # diff = g * (model_output_uncond - model_output_cond) + # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) From 78fca12803d69541fc63161b036db96792520fd8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 13:23:24 +0200 Subject: [PATCH 18/39] support cfgpp in ddim --- .../classifier_free_guidance_plus_plus.py | 20 ++++++++---------- src/diffusers/guiders/guider_utils.py | 11 ++-------- .../pipeline_stable_diffusion_xl_modular.py | 21 ++++++++----------- src/diffusers/schedulers/scheduling_ddim.py | 9 +++++++- .../schedulers/scheduling_euler_discrete.py | 11 +++++++++- 5 files changed, 38 insertions(+), 34 deletions(-) diff --git a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py index 516dbfa0e05f..d1c6f8744143 100644 --- a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py +++ b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Optional, Union, Tuple, List +from typing import Dict, Optional, Union, Tuple, List import torch @@ -84,15 +83,6 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = return pred - def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor: - if self._is_cfgpp_enabled(): - # TODO(aryan): this probably only makes sense for EulerDiscreteScheduler. Look into the others later! - pred_cond = self._preds["pred_cond"] - pred_uncond = self._preds["pred_uncond"] - diff = pred_uncond - pred_cond - pred = pred + diff * self.guidance_scale * self._sigma_next - return pred - @property def is_conditional(self) -> bool: return self._num_outputs_prepared == 0 @@ -104,6 +94,14 @@ def num_conditions(self) -> int: num_conditions += 1 return num_conditions + @property + def outputs(self) -> Dict[str, torch.Tensor]: + scheduler_step_kwargs = {} + if self._is_cfgpp_enabled(): + scheduler_step_kwargs["_use_cfgpp"] = True + scheduler_step_kwargs["_model_output_uncond"] = self._preds.get("pred_uncond") + return self._preds, scheduler_step_kwargs + def _is_cfgpp_enabled(self) -> bool: if not self._enabled: return False diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index f51452ed0cee..420a56690678 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -37,8 +37,6 @@ def __init__(self, start: float = 0.0, stop: float = 1.0): self._step: int = None self._num_inference_steps: int = None self._timestep: torch.LongTensor = None - self._sigma: torch.Tensor = None - self._sigma_next: torch.Tensor = None self._preds: Dict[str, torch.Tensor] = {} self._num_outputs_prepared: int = 0 self._enabled = True @@ -63,12 +61,10 @@ def _force_disable(self): def _force_enable(self): self._enabled = True - def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor, sigma: torch.Tensor, sigma_next: torch.Tensor) -> None: + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: self._step = step self._num_inference_steps = num_inference_steps self._timestep = timestep - self._sigma = sigma - self._sigma_next = sigma_next self._preds = {} self._num_outputs_prepared = 0 @@ -95,9 +91,6 @@ def __call__(self, **kwargs) -> Any: def forward(self, *args, **kwargs) -> Any: raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") - def post_scheduler_step(self, pred: torch.Tensor) -> torch.Tensor: - return pred - @property def is_conditional(self) -> bool: raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") @@ -112,7 +105,7 @@ def num_conditions(self) -> int: @property def outputs(self) -> Dict[str, torch.Tensor]: - return self._preds + return self._preds, {} def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 8e0ea4545f29..37d2bbbe6ca6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -2241,7 +2241,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) ( latents, @@ -2295,13 +2295,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) # Perform guidance - outputs = pipeline.guider.outputs + outputs, scheduler_step_kwargs = pipeline.guider.outputs data.noise_pred = pipeline.guider(**outputs) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - data.latents = pipeline.guider.post_scheduler_step(data.latents) + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): @@ -2638,7 +2637,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # (5) Denoise loop with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) ( latents, @@ -2725,13 +2724,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) # Perform guidance - outputs = pipeline.guider.outputs + outputs, scheduler_step_kwargs = pipeline.guider.outputs data.noise_pred = pipeline.guider(**outputs) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - data.latents = pipeline.guider.post_scheduler_step(data.latents) + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): @@ -3055,7 +3053,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): - pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t, sigma=pipeline.scheduler.sigmas[i], sigma_next=pipeline.scheduler.sigmas[i + 1]) + pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) ( latents, @@ -3144,13 +3142,12 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.noise_pred = pipeline.guider.prepare_outputs(pipeline.unet, data.noise_pred) # Perform guidance - outputs = pipeline.guider.outputs + outputs, scheduler_step_kwargs = pipeline.guider.outputs data.noise_pred = pipeline.guider(**outputs) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - data.latents = pipeline.guider.post_scheduler_step(data.latents) + data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 13c9b3b4a5e9..2e74c9bbfccd 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -349,6 +349,8 @@ def step( generator=None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, + _model_output_uncond: Optional[torch.Tensor] = None, + _use_cfgpp: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -386,6 +388,11 @@ def step( raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) + + if _use_cfgpp and self.config.prediction_type != "epsilon": + raise ValueError( + f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}." + ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding @@ -411,7 +418,7 @@ def step( # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - pred_epsilon = model_output + pred_epsilon = model_output if not _use_cfgpp else _model_output_uncond elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 4adec768b776..4c82ca7e389b 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -584,6 +584,8 @@ def step( s_noise: float = 1.0, generator: Optional[torch.Generator] = None, return_dict: bool = True, + _model_output_uncond: Optional[torch.Tensor] = None, + _use_cfgpp: bool = False, ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -627,6 +629,11 @@ def step( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) + + if _use_cfgpp and self.config.prediction_type != "epsilon": + raise ValueError( + f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}." + ) if self.step_index is None: self._init_step_index(timestep) @@ -668,7 +675,9 @@ def step( dt = self.sigmas[self.step_index + 1] - sigma_hat prev_sample = sample + derivative * dt - + if _use_cfgpp: + prev_sample = prev_sample + (_model_output_uncond - model_output) * self.sigmas[self.step_index + 1] + # denoised = sample - model_output * sigmas[i] # d = (sample - denoised) / sigmas[i] # new_sample = denoised + d * sigmas[i + 1] From 19555e95cce7c142453ebbbaa23646d66044c7cb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 22 Apr 2025 10:33:03 +0200 Subject: [PATCH 19/39] update doc & repr --- src/diffusers/guiders/__init__.py | 11 +- src/diffusers/pipelines/modular_pipeline.py | 392 ++++++++++---------- 2 files changed, 200 insertions(+), 203 deletions(-) diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 56e95c92b697..a0a905f2b522 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Union - +from enum import Enum from ..utils import is_torch_available @@ -27,4 +27,11 @@ from .smoothed_energy_guidance import SmoothedEnergyGuidance from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance - GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance] + class GuiderType(Enum): + AdaptiveProjectedGuidance=1, + AutoGuidance=2, + ClassifierFreeGuidance=3, + ClassifierFreeZeroStarGuidance=4, + SkipLayerGuidance=5, + SmoothedEnergyGuidance=6, + TangentialClassifierFreeGuidance=7 diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 785f38cdbf8c..b896066edf08 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -336,30 +336,143 @@ def wrap_text(text: str, indent: str, max_length: int) -> str: # Then update the original functions to use this combined version: def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(input_params, "Args", indent_level, max_line_length) + return format_params(input_params, "Inputs", indent_level, max_line_length) def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(output_params, "Returns", indent_level, max_line_length) + return format_params(output_params, "Outputs", indent_level, max_line_length) +def format_components(components: List[ComponentSpec], indent_level: int = 4, max_line_length: int = 115, add_empty_lines: bool = True) -> str: + """Format a list of ComponentSpec objects into a readable string representation. -def make_doc_string(inputs, intermediates_inputs, outputs, description=""): + Args: + components: List of ComponentSpec objects to format + indent_level: Number of spaces to indent each component line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between components (default: True) + + Returns: + A formatted string representing all components + """ + if not components: + return "" + + base_indent = " " * indent_level + component_indent = " " * (indent_level + 4) + formatted_components = [] + + # Add the header + formatted_components.append(f"{base_indent}Components:") + if add_empty_lines: + formatted_components.append("") + + # Add each component with optional empty lines between them + for i, component in enumerate(components): + # Get type name, handling special cases + type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + + component_desc = f"{component_indent}{component.name} (`{type_name}`)" + if component.description: + component_desc += f": {component.description}" + if component.default_repo: + if isinstance(component.default_repo, list) and len(component.default_repo) == 2: + repo_info = component.default_repo[0] + subfolder = component.default_repo[1] + if subfolder: + repo_info += f", subfolder={subfolder}" + else: + repo_info = component.default_repo + component_desc += f" [{repo_info}]" + formatted_components.append(component_desc) + + # Add an empty line after each component except the last one + if add_empty_lines and i < len(components) - 1: + formatted_components.append("") + + return "\n".join(formatted_components) + + +def format_configs(configs: List[ConfigSpec], indent_level: int = 4, max_line_length: int = 115, add_empty_lines: bool = True) -> str: + """Format a list of ConfigSpec objects into a readable string representation. + + Args: + configs: List of ConfigSpec objects to format + indent_level: Number of spaces to indent each config line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between configs (default: True) + + Returns: + A formatted string representing all configs + """ + if not configs: + return "" + + base_indent = " " * indent_level + config_indent = " " * (indent_level + 4) + formatted_configs = [] + + # Add the header + formatted_configs.append(f"{base_indent}Configs:") + if add_empty_lines: + formatted_configs.append("") + + # Add each config with optional empty lines between them + for i, config in enumerate(configs): + config_desc = f"{config_indent}{config.name} (default: {config.default})" + if config.description: + config_desc += f": {config.description}" + formatted_configs.append(config_desc) + + # Add an empty line after each config except the last one + if add_empty_lines and i < len(configs) - 1: + formatted_configs.append("") + + return "\n".join(formatted_configs) + + +def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): """ Generates a formatted documentation string describing the pipeline block's parameters and structure. + Args: + inputs (List[InputParam]): List of input parameters + intermediates_inputs (List[InputParam]): List of intermediate input parameters + outputs (List[OutputParam]): List of output parameters + description (str, *optional*): Description of the block + class_name (str, *optional*): Name of the class to include in the documentation + expected_components (List[ComponentSpec], *optional*): List of expected components + expected_configs (List[ConfigSpec], *optional*): List of expected configurations + Returns: - str: A formatted string containing information about call parameters, intermediate inputs/outputs, - and final intermediate outputs. + str: A formatted string containing information about components, configs, call parameters, + intermediate inputs/outputs, and final outputs. """ output = "" + # Add class name if provided + if class_name: + output += f"class {class_name}\n\n" + + # Add description if description: desc_lines = description.strip().split('\n') aligned_desc = '\n'.join(' ' + line for line in desc_lines) output += aligned_desc + "\n\n" + # Add components section if provided + if expected_components and len(expected_components) > 0: + components_str = format_components(expected_components, indent_level=2) + output += components_str + "\n\n" + + # Add configs section if provided + if expected_configs and len(expected_configs) > 0: + configs_str = format_configs(expected_configs, indent_level=2) + output += configs_str + "\n\n" + + # Add inputs section output += format_input_params(inputs + intermediates_inputs, indent_level=2) + # Add outputs section output += "\n\n" output += format_output_params(outputs, indent_level=2) @@ -440,31 +553,15 @@ def __repr__(self): desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section - focus only on expected components + # Components section - use format_components with add_empty_lines=False expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] - - for component_spec in expected_components: - component_str = f" - {component_spec.name} ({component_spec.type_hint})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) - - components = "Components:\n" + "\n".join(expected_components_str_list) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + components = " " + components_str.replace("\n", "\n ") - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs = "Configs:\n" + "\n".join(f" - {k}" for k in sorted(expected_configs)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + configs = " " + configs_str.replace("\n", "\n ") # Inputs section inputs_str = format_inputs_short(self.inputs) @@ -478,8 +575,8 @@ def __repr__(self): f"{class_name}(\n" f" Class: {base_class}\n" f"{desc}" - f" {components}\n" - f" {configs}\n" + f"{components}\n" + f"{configs}\n" f" {inputs}\n" f" {intermediates}\n" f")" @@ -488,7 +585,15 @@ def __repr__(self): @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) def get_block_state(self, state: PipelineState) -> dict: @@ -796,32 +901,25 @@ def __repr__(self): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - for component_spec in expected_components: - - component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) - - components_str = " Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Blocks section + # Inputs and outputs section - moved up + inputs_str = format_inputs_short(self.inputs) + inputs_str = " Inputs:\n " + inputs_str + + outputs = [out.name for out in self.outputs] + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates_str = ( + " Intermediates:\n" + f"{intermediates_str}\n" + f" - final outputs: {', '.join(outputs)}" + ) + + # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -846,52 +944,31 @@ def __repr__(self): indented_desc = desc_lines[0] if len(desc_lines) > 1: indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - # Inputs and outputs section - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) + blocks_str += f" Description: {indented_desc}\n\n" return ( f"{header}\n" f"{desc}" f"{components_str}\n" f"{configs_str}\n" - f"{blocks_str}\n" f"{inputs_str}\n" f"{intermediates_str}\n" + f"{blocks_str}" f")" ) + @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) class SequentialPipelineBlocks: """ @@ -1166,34 +1243,27 @@ def __repr__(self): desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section - focus only on expected components + # Components section - use format_components with add_empty_lines=False expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - for component_spec in expected_components: - - component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) - - components_str = " Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Blocks section + # Inputs and outputs section - moved up + inputs_str = format_inputs_short(self.inputs) + inputs_str = " Inputs:\n " + inputs_str + + outputs = [out.name for out in self.outputs] + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates_str = ( + " Intermediates:\n" + f"{intermediates_str}\n" + f" - final outputs: {', '.join(outputs)}" + ) + + # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -1218,53 +1288,31 @@ def __repr__(self): indented_desc = desc_lines[0] if len(desc_lines) > 1: indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - # Inputs and outputs section - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) + blocks_str += f" Description: {indented_desc}\n\n" return ( f"{header}\n" f"{desc}" f"{components_str}\n" f"{configs_str}\n" - f"{blocks_str}\n" f"{inputs_str}\n" f"{intermediates_str}\n" + f"{blocks_str}" f")" ) @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) class ModularPipeline(ConfigMixin): """ @@ -1486,64 +1534,6 @@ def default_call_parameters(self) -> Dict[str, Any]: params[input_param.name] = input_param.default return params - # def __repr__(self): - # output = "ModularPipeline:\n" - # output += "==============================\n\n" - - # block = self.pipeline_block - - # # List the pipeline block structure first - # output += "Pipeline Block:\n" - # output += "--------------\n" - # if hasattr(block, "blocks"): - # output += f"{block.__class__.__name__}\n" - # base_class = block.__class__.__bases__[0].__name__ - # output += f" (Class: {base_class})\n" if base_class != "object" else "\n" - # for sub_block_name, sub_block in block.blocks.items(): - # if hasattr(block, "block_trigger_inputs"): - # trigger_input = block.block_to_trigger_map[sub_block_name] - # trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" - # output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" - # else: - # output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" - # else: - # output += f"{block.__class__.__name__}\n" - # output += "\n" - - # # List the components registered in the pipeline - # output += "Registered Components:\n" - # output += "----------------------\n" - # for name, component in self.components.items(): - # output += f"{name}: {type(component).__name__}" - # if hasattr(component, "dtype") and hasattr(component, "device"): - # output += f" (dtype={component.dtype}, device={component.device})" - # output += "\n" - # output += "\n" - - # # List the configs registered in the pipeline - # output += "Registered Configs:\n" - # output += "------------------\n" - # for name, config in self.config.items(): - # output += f"{name}: {config!r}\n" - # output += "\n" - - # # Add auto blocks section - # if hasattr(block, "trigger_inputs") and block.trigger_inputs: - # output += "------------------\n" - # output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" - # output += f"Trigger Inputs: {block.trigger_inputs}\n" - # # Get first trigger input as example - # example_input = next(t for t in block.trigger_inputs if t is not None) - # output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - # output += "Check `.doc` of returned object for more information.\n\n" - - # # List the call parameters - # full_doc = self.pipeline_block.doc - # if "------------------------" in full_doc: - # full_doc = full_doc.split("------------------------")[0].rstrip() - # output += full_doc - - # return output # YiYi TODO: try to unify the to method with the one in DiffusionPipeline # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to From 01e9a5a941ee4d3d23934e33afe4654be7229470 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 22 Apr 2025 10:37:45 +0200 Subject: [PATCH 20/39] move formating utilitys to modular_pipeline_util.py --- src/diffusers/pipelines/modular_pipeline.py | 313 +--------------- .../pipelines/modular_pipeline_util.py | 345 ++++++++++++++++++ 2 files changed, 355 insertions(+), 303 deletions(-) create mode 100644 src/diffusers/pipelines/modular_pipeline_util.py diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index b896066edf08..724c31f374ef 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -30,6 +30,16 @@ logging, ) from .pipeline_loading_utils import _get_pipeline_class +from .modular_pipeline_util import ( + format_components, + format_configs, + format_input_params, + format_inputs_short, + format_intermediates_short, + format_output_params, + format_params, + make_doc_string, +) if is_accelerate_available(): @@ -176,309 +186,6 @@ class OutputParam: def __repr__(self): return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" -def format_inputs_short(inputs): - """ - Format input parameters into a string representation, with required params first followed by optional ones. - - Args: - inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params - - Returns: - str: Formatted string of input parameters - - Example: - >>> inputs = [ - ... InputParam(name="prompt", required=True), - ... InputParam(name="image", required=True), - ... InputParam(name="guidance_scale", required=False, default=7.5), - ... InputParam(name="num_inference_steps", required=False, default=50) - ... ] - >>> format_inputs_short(inputs) - 'prompt, image, guidance_scale=7.5, num_inference_steps=50' - """ - required_inputs = [param for param in inputs if param.required] - optional_inputs = [param for param in inputs if not param.required] - - required_str = ", ".join(param.name for param in required_inputs) - optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - - inputs_str = required_str - if optional_str: - inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - - return inputs_str - - -def format_intermediates_short(intermediates_inputs: List[InputParam], required_intermediates_inputs: List[str], intermediates_outputs: List[OutputParam]) -> str: - """ - Formats intermediate inputs and outputs of a block into a string representation. - - Args: - intermediates_inputs: List of intermediate input parameters - required_intermediates_inputs: List of required intermediate input names - intermediates_outputs: List of intermediate output parameters - - Returns: - str: Formatted string like: - Intermediates: - - inputs: Required(latents), dtype - - modified: latents # variables that appear in both inputs and outputs - - outputs: images # new outputs only - """ - # Handle inputs - input_parts = [] - for inp in intermediates_inputs: - if inp.name in required_intermediates_inputs: - input_parts.append(f"Required({inp.name})") - else: - input_parts.append(inp.name) - - # Handle modified variables (appear in both inputs and outputs) - inputs_set = {inp.name for inp in intermediates_inputs} - modified_parts = [] - new_output_parts = [] - - for out in intermediates_outputs: - if out.name in inputs_set: - modified_parts.append(out.name) - else: - new_output_parts.append(out.name) - - result = [] - if input_parts: - result.append(f" - inputs: {', '.join(input_parts)}") - if modified_parts: - result.append(f" - modified: {', '.join(modified_parts)}") - if new_output_parts: - result.append(f" - outputs: {', '.join(new_output_parts)}") - - return "\n".join(result) if result else " (none)" - - -def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: int = 115) -> str: - """Format a list of InputParam or OutputParam objects into a readable string representation. - - Args: - params: List of InputParam or OutputParam objects to format - header: Header text to use (e.g. "Args" or "Returns") - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all parameters - """ - if not params: - return "" - - base_indent = " " * indent_level - param_indent = " " * (indent_level + 4) - desc_indent = " " * (indent_level + 8) - formatted_params = [] - - def get_type_str(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: - types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] - return f"Union[{', '.join(types)}]" - return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - - def wrap_text(text: str, indent: str, max_length: int) -> str: - """Wrap text while preserving markdown links and maintaining indentation.""" - words = text.split() - lines = [] - current_line = [] - current_length = 0 - - for word in words: - word_length = len(word) + (1 if current_line else 0) - - if current_line and current_length + word_length > max_length: - lines.append(" ".join(current_line)) - current_line = [word] - current_length = len(word) - else: - current_line.append(word) - current_length += word_length - - if current_line: - lines.append(" ".join(current_line)) - - return f"\n{indent}".join(lines) - - # Add the header - formatted_params.append(f"{base_indent}{header}:") - - for param in params: - # Format parameter name and type - type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`" - - # Add optional tag and default value if parameter is an InputParam and optional - if isinstance(param, InputParam): - if not param.required: - param_str += ", *optional*" - if param.default is not None: - param_str += f", defaults to {param.default}" - param_str += "):" - - # Add description on a new line with additional indentation and wrapping - if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) - wrapped_desc = wrap_text(desc, desc_indent, max_line_length) - param_str += f"\n{desc_indent}{wrapped_desc}" - - formatted_params.append(param_str) - - return "\n\n".join(formatted_params) - -# Then update the original functions to use this combined version: -def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(input_params, "Inputs", indent_level, max_line_length) - -def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(output_params, "Outputs", indent_level, max_line_length) - - -def format_components(components: List[ComponentSpec], indent_level: int = 4, max_line_length: int = 115, add_empty_lines: bool = True) -> str: - """Format a list of ComponentSpec objects into a readable string representation. - - Args: - components: List of ComponentSpec objects to format - indent_level: Number of spaces to indent each component line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - add_empty_lines: Whether to add empty lines between components (default: True) - - Returns: - A formatted string representing all components - """ - if not components: - return "" - - base_indent = " " * indent_level - component_indent = " " * (indent_level + 4) - formatted_components = [] - - # Add the header - formatted_components.append(f"{base_indent}Components:") - if add_empty_lines: - formatted_components.append("") - - # Add each component with optional empty lines between them - for i, component in enumerate(components): - # Get type name, handling special cases - type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) - - component_desc = f"{component_indent}{component.name} (`{type_name}`)" - if component.description: - component_desc += f": {component.description}" - if component.default_repo: - if isinstance(component.default_repo, list) and len(component.default_repo) == 2: - repo_info = component.default_repo[0] - subfolder = component.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component.default_repo - component_desc += f" [{repo_info}]" - formatted_components.append(component_desc) - - # Add an empty line after each component except the last one - if add_empty_lines and i < len(components) - 1: - formatted_components.append("") - - return "\n".join(formatted_components) - - -def format_configs(configs: List[ConfigSpec], indent_level: int = 4, max_line_length: int = 115, add_empty_lines: bool = True) -> str: - """Format a list of ConfigSpec objects into a readable string representation. - - Args: - configs: List of ConfigSpec objects to format - indent_level: Number of spaces to indent each config line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - add_empty_lines: Whether to add empty lines between configs (default: True) - - Returns: - A formatted string representing all configs - """ - if not configs: - return "" - - base_indent = " " * indent_level - config_indent = " " * (indent_level + 4) - formatted_configs = [] - - # Add the header - formatted_configs.append(f"{base_indent}Configs:") - if add_empty_lines: - formatted_configs.append("") - - # Add each config with optional empty lines between them - for i, config in enumerate(configs): - config_desc = f"{config_indent}{config.name} (default: {config.default})" - if config.description: - config_desc += f": {config.description}" - formatted_configs.append(config_desc) - - # Add an empty line after each config except the last one - if add_empty_lines and i < len(configs) - 1: - formatted_configs.append("") - - return "\n".join(formatted_configs) - - -def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Args: - inputs (List[InputParam]): List of input parameters - intermediates_inputs (List[InputParam]): List of intermediate input parameters - outputs (List[OutputParam]): List of output parameters - description (str, *optional*): Description of the block - class_name (str, *optional*): Name of the class to include in the documentation - expected_components (List[ComponentSpec], *optional*): List of expected components - expected_configs (List[ConfigSpec], *optional*): List of expected configurations - - Returns: - str: A formatted string containing information about components, configs, call parameters, - intermediate inputs/outputs, and final outputs. - """ - output = "" - - # Add class name if provided - if class_name: - output += f"class {class_name}\n\n" - - # Add description - if description: - desc_lines = description.strip().split('\n') - aligned_desc = '\n'.join(' ' + line for line in desc_lines) - output += aligned_desc + "\n\n" - - # Add components section if provided - if expected_components and len(expected_components) > 0: - components_str = format_components(expected_components, indent_level=2) - output += components_str + "\n\n" - - # Add configs section if provided - if expected_configs and len(expected_configs) > 0: - configs_str = format_configs(expected_configs, indent_level=2) - output += configs_str + "\n\n" - - # Add inputs section - output += format_input_params(inputs + intermediates_inputs, indent_level=2) - - # Add outputs section - output += "\n\n" - output += format_output_params(outputs, indent_level=2) - - return output - - class PipelineBlock: diff --git a/src/diffusers/pipelines/modular_pipeline_util.py b/src/diffusers/pipelines/modular_pipeline_util.py new file mode 100644 index 000000000000..fb6b83c7eee0 --- /dev/null +++ b/src/diffusers/pipelines/modular_pipeline_util.py @@ -0,0 +1,345 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any, Dict, List, Union + +from ..utils.import_utils import is_torch_available + +if is_torch_available(): + import torch + + +def format_inputs_short(inputs): + """ + Format input parameters into a string representation, with required params first followed by optional ones. + + Args: + inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params + + Returns: + str: Formatted string of input parameters + + Example: + >>> inputs = [ + ... InputParam(name="prompt", required=True), + ... InputParam(name="image", required=True), + ... InputParam(name="guidance_scale", required=False, default=7.5), + ... InputParam(name="num_inference_steps", required=False, default=50) + ... ] + >>> format_inputs_short(inputs) + 'prompt, image, guidance_scale=7.5, num_inference_steps=50' + """ + required_inputs = [param for param in inputs if param.required] + optional_inputs = [param for param in inputs if not param.required] + + required_str = ", ".join(param.name for param in required_inputs) + optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) + + inputs_str = required_str + if optional_str: + inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str + + return inputs_str + + +def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): + """ + Formats intermediate inputs and outputs of a block into a string representation. + + Args: + intermediates_inputs: List of intermediate input parameters + required_intermediates_inputs: List of required intermediate input names + intermediates_outputs: List of intermediate output parameters + + Returns: + str: Formatted string like: + Intermediates: + - inputs: Required(latents), dtype + - modified: latents # variables that appear in both inputs and outputs + - outputs: images # new outputs only + """ + # Handle inputs + input_parts = [] + for inp in intermediates_inputs: + if inp.name in required_intermediates_inputs: + input_parts.append(f"Required({inp.name})") + else: + input_parts.append(inp.name) + + # Handle modified variables (appear in both inputs and outputs) + inputs_set = {inp.name for inp in intermediates_inputs} + modified_parts = [] + new_output_parts = [] + + for out in intermediates_outputs: + if out.name in inputs_set: + modified_parts.append(out.name) + else: + new_output_parts.append(out.name) + + result = [] + if input_parts: + result.append(f" - inputs: {', '.join(input_parts)}") + if modified_parts: + result.append(f" - modified: {', '.join(modified_parts)}") + if new_output_parts: + result.append(f" - outputs: {', '.join(new_output_parts)}") + + return "\n".join(result) if result else " (none)" + + +def format_params(params, header="Args", indent_level=4, max_line_length=115): + """Format a list of InputParam or OutputParam objects into a readable string representation. + + Args: + params: List of InputParam or OutputParam objects to format + header: Header text to use (e.g. "Args" or "Returns") + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all parameters + """ + if not params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text, indent, max_length): + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + word_length = len(word) + (1 if current_line else 0) + + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + return f"\n{indent}".join(lines) + + # Add the header + formatted_params.append(f"{base_indent}{header}:") + + for param in params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + param_str = f"{param_indent}{param.name} (`{type_str}`" + + # Add optional tag and default value if parameter is an InputParam and optional + if hasattr(param, "required"): + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" + param_str += "):" + + # Add description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub( + r'\[(.*?)\]\((https?://[^\s\)]+)\)', + r'[\1](\2)', + param.description + ) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) + + +def format_input_params(input_params, indent_level=4, max_line_length=115): + """Format a list of InputParam objects into a readable string representation. + + Args: + input_params: List of InputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all input parameters + """ + return format_params(input_params, "Inputs", indent_level, max_line_length) + + +def format_output_params(output_params, indent_level=4, max_line_length=115): + """Format a list of OutputParam objects into a readable string representation. + + Args: + output_params: List of OutputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all output parameters + """ + return format_params(output_params, "Outputs", indent_level, max_line_length) + + +def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ComponentSpec objects into a readable string representation. + + Args: + components: List of ComponentSpec objects to format + indent_level: Number of spaces to indent each component line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between components (default: True) + + Returns: + A formatted string representing all components + """ + if not components: + return "" + + base_indent = " " * indent_level + component_indent = " " * (indent_level + 4) + formatted_components = [] + + # Add the header + formatted_components.append(f"{base_indent}Components:") + if add_empty_lines: + formatted_components.append("") + + # Add each component with optional empty lines between them + for i, component in enumerate(components): + # Get type name, handling special cases + type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + + component_desc = f"{component_indent}{component.name} (`{type_name}`)" + if component.description: + component_desc += f": {component.description}" + if component.default_repo: + if isinstance(component.default_repo, list) and len(component.default_repo) == 2: + repo_info = component.default_repo[0] + subfolder = component.default_repo[1] + if subfolder: + repo_info += f", subfolder={subfolder}" + else: + repo_info = component.default_repo + component_desc += f" [{repo_info}]" + formatted_components.append(component_desc) + + # Add an empty line after each component except the last one + if add_empty_lines and i < len(components) - 1: + formatted_components.append("") + + return "\n".join(formatted_components) + + +def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ConfigSpec objects into a readable string representation. + + Args: + configs: List of ConfigSpec objects to format + indent_level: Number of spaces to indent each config line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between configs (default: True) + + Returns: + A formatted string representing all configs + """ + if not configs: + return "" + + base_indent = " " * indent_level + config_indent = " " * (indent_level + 4) + formatted_configs = [] + + # Add the header + formatted_configs.append(f"{base_indent}Configs:") + if add_empty_lines: + formatted_configs.append("") + + # Add each config with optional empty lines between them + for i, config in enumerate(configs): + config_desc = f"{config_indent}{config.name} (default: {config.default})" + if config.description: + config_desc += f": {config.description}" + formatted_configs.append(config_desc) + + # Add an empty line after each config except the last one + if add_empty_lines and i < len(configs) - 1: + formatted_configs.append("") + + return "\n".join(formatted_configs) + + +def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Args: + inputs: List of input parameters + intermediates_inputs: List of intermediate input parameters + outputs: List of output parameters + description (str, *optional*): Description of the block + class_name (str, *optional*): Name of the class to include in the documentation + expected_components (List[ComponentSpec], *optional*): List of expected components + expected_configs (List[ConfigSpec], *optional*): List of expected configurations + + Returns: + str: A formatted string containing information about components, configs, call parameters, + intermediate inputs/outputs, and final outputs. + """ + output = "" + + # Add class name if provided + if class_name: + output += f"class {class_name}\n\n" + + # Add description + if description: + desc_lines = description.strip().split('\n') + aligned_desc = '\n'.join(' ' + line for line in desc_lines) + output += aligned_desc + "\n\n" + + # Add components section if provided + if expected_components and len(expected_components) > 0: + components_str = format_components(expected_components, indent_level=2) + output += components_str + "\n\n" + + # Add configs section if provided + if expected_configs and len(expected_configs) > 0: + configs_str = format_configs(expected_configs, indent_level=2) + output += configs_str + "\n\n" + + # Add inputs section + output += format_input_params(inputs + intermediates_inputs, indent_level=2) + + # Add outputs section + output += "\n\n" + output += format_output_params(outputs, indent_level=2) + + return output \ No newline at end of file From 170a3c57367a68fb8f4cbc2ad382457fcc6e6ed6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 22 Apr 2025 18:48:20 +0200 Subject: [PATCH 21/39] attemp to break ModularPipeline base into componentstate and a pipelineblock mixin --- src/diffusers/pipelines/modular_pipeline.py | 348 +++++------------- .../pipeline_stable_diffusion_xl_modular.py | 2 +- 2 files changed, 96 insertions(+), 254 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 724c31f374ef..09736712a28b 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -1021,37 +1021,111 @@ def doc(self): expected_configs=self.expected_configs ) -class ModularPipeline(ConfigMixin): + + +class ModularPipelineMixin: """ - Base class for all Modular pipelines. + Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks + """ + + + def __init__(self): + self.components_manager = None + self.components_manager_prefix = "" + self.components_state = None + + # YiYi TODO: not sure this is the best method name + def compile(self, components_manager: ComponentsManager, label: Optional[str] = None): + self.components_manager = components_manager + self.components_manager_prefix = "" if label is None else f"{label}_" + self.components_state = ComponentsState(self.expected_components, self.expected_configs) + + components_to_add = self.components_manager.get(f"{self.components_manager_prefix}*") + self.components_state.update_states(self.expected_components, self.expected_configs, **components_to_add) + + + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for input_param in self.inputs: + params[input_param.name] = input_param.default + return params + + def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() + + # Make a copy of the input kwargs + input_params = kwargs.copy() + + default_params = self.default_call_parameters + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + + intermediates_inputs = [inp.name for inp in self.intermediates_inputs] + for name, default in default_params.items(): + if name in input_params: + if name not in intermediates_inputs: + state.add_input(name, input_params.pop(name)) + else: + state.add_input(name, input_params[name]) + elif name not in state.inputs: + state.add_input(name, default) + + for name in intermediates_inputs: + if name in input_params: + state.add_intermediate(name, input_params.pop(name)) + + # Warn about unexpected inputs + if len(input_params) > 0: + logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + pipeline, state = self(self, state) + except Exception: + error_msg = f"Error in block: ({self.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + + elif isinstance(output, str): + return state.get_intermediate(output) + + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + +class ComponentsState(ConfigMixin): """ + Base class for all Modular pipelines. + """ config_name = "model_index.json" - _exclude_from_cpu_offload = [] - def __init__(self, block): - self.pipeline_block = block + def __init__(self, component_specs, config_specs): - for component_spec in self.expected_components: + for component_spec in component_specs: if component_spec.obj is not None: setattr(self, component_spec.name, component_spec.obj) else: setattr(self, component_spec.name, None) default_configs = {} - for config_spec in self.expected_configs: + for config_spec in config_specs: default_configs[config_spec.name] = config_spec.default self.register_to_config(**default_configs) - @classmethod - def from_block(cls, block): - modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[block.model_name] - modular_pipeline_class = _get_pipeline_class(cls, class_name=modular_pipeline_class_name) - - return modular_pipeline_class(block) - @property def device(self) -> torch.device: r""" @@ -1089,10 +1163,7 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - - def get_execution_blocks(self, *trigger_inputs): - return self.pipeline_block.get_execution_blocks(*trigger_inputs) - + @property def dtype(self) -> torch.dtype: r""" @@ -1107,13 +1178,6 @@ def dtype(self) -> torch.dtype: return torch.float32 - @property - def expected_components(self): - return self.pipeline_block.expected_components - - @property - def expected_configs(self): - return self.pipeline_block.expected_configs @property def components(self): @@ -1123,80 +1187,7 @@ def components(self): components[component_spec.name] = getattr(self, component_spec.name) return components - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.set_progress_bar_config - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - # Make a copy of the input kwargs - input_params = kwargs.copy() - - default_params = self.default_call_parameters - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediates_inputs = [inp.name for inp in self.pipeline_block.intermediates_inputs] - for name, default in default_params.items(): - if name in input_params: - if name not in intermediates_inputs: - state.add_input(name, input_params.pop(name)) - else: - state.add_input(name, input_params[name]) - elif name not in state.inputs: - state.add_input(name, default) - - for name in intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) - - # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self.pipeline_block(self, state) - except Exception: - error_msg = f"Error in block: ({self.pipeline_block.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - - def update_states(self, **kwargs): + def update_states(self, expected_components, expected_configs, **kwargs): """ Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for each pipeline block, does not need to be updated by users. Logs if existing non-None components are being @@ -1206,7 +1197,7 @@ def update_states(self, **kwargs): kwargs (dict): Keyword arguments to update the states. """ - for component in self.expected_components: + for component in expected_components: if component.name in kwargs: if hasattr(self, component.name) and getattr(self, component.name) is not None: current_component = getattr(self, component.name) @@ -1226,163 +1217,14 @@ def update_states(self, **kwargs): f"with new value (type: {type(new_component).__name__})" ) - setattr(self, component.name, kwargs.pop(component.name)) + setattr(self.components_state, component.name, kwargs.pop(component.name)) configs_to_add = {} - for config in self.expected_configs: + for config in expected_configs: if config.name in kwargs: configs_to_add[config.name] = kwargs.pop(config.name) self.register_to_config(**configs_to_add) - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.pipeline_block.inputs: - params[input_param.name] = input_param.default - return params - - - # YiYi TODO: try to unify the to method with the one in DiffusionPipeline - # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to + # YiYi TODO: should support to method def to(self, *args, **kwargs): - r""" - Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the - arguments of `self.to(*args, **kwargs).` - - - - If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise, - the returned pipeline is a copy of self with the desired torch.dtype and torch.device. - - - - - Here are the ways to call `to`: - - - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified - [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) - - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the - specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - - Arguments: - dtype (`torch.dtype`, *optional*): - Returns a pipeline with the specified - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - device (`torch.Device`, *optional*): - Returns a pipeline with the specified - [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) - silence_dtype_warnings (`str`, *optional*, defaults to `False`): - Whether to omit warnings if the target `dtype` is not compatible with the target `device`. - - Returns: - [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. - """ - dtype = kwargs.pop("dtype", None) - device = kwargs.pop("device", None) - silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) - - dtype_arg = None - device_arg = None - if len(args) == 1: - if isinstance(args[0], torch.dtype): - dtype_arg = args[0] - else: - device_arg = torch.device(args[0]) if args[0] is not None else None - elif len(args) == 2: - if isinstance(args[0], torch.dtype): - raise ValueError( - "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." - ) - device_arg = torch.device(args[0]) if args[0] is not None else None - dtype_arg = args[1] - elif len(args) > 2: - raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`") - - if dtype is not None and dtype_arg is not None: - raise ValueError( - "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." - ) - - dtype = dtype or dtype_arg - - if device is not None and device_arg is not None: - raise ValueError( - "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." - ) - - device = device or device_arg - - # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. - def module_is_sequentially_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): - return False - - return hasattr(module, "_hf_hook") and ( - isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) - or hasattr(module._hf_hook, "hooks") - and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook) - ) - - def module_is_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): - return False - - return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) - - # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer - pipeline_is_sequentially_offloaded = any( - module_is_sequentially_offloaded(module) for _, module in self.components.items() - ) - if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": - raise ValueError( - "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." - ) - - is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 - if is_pipeline_device_mapped: - raise ValueError( - "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`." - ) - - # Display a warning in this case (the operation succeeds but the benefits are lost) - pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) - if pipeline_is_offloaded and device and torch.device(device).type == "cuda": - logger.warning( - f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." - ) - - modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] - - is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded - for module in modules: - is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit - - if is_loaded_in_8bit and dtype is not None: - logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." - ) - - if is_loaded_in_8bit and device is not None: - logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." - ) - else: - module.to(device, dtype) - - if ( - module.dtype == torch.float16 - and str(device) in ["cpu"] - and not silence_dtype_warnings - and not is_offloaded - ): - logger.warning( - "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" - " is not recommended to move them to `cpu` as running them will fail. Please make" - " sure to use an accelerator to run the pipeline in inference, due to the lack of" - " support for`float16` operations on this device in PyTorch. Please, remove the" - " `torch_dtype=torch.float16` argument, or use another device for inference." - ) - return self + pass diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 37d2bbbe6ca6..882ec2e18552 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -3571,7 +3571,7 @@ def description(self): # YiYi TODO: rename to components etc. and not inherit from ModularPipeline -class StableDiffusionXLModularPipeline( +class StableDiffusionXLComponentStates( ModularPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, From e38f09ba97a87596ef2ded5485a824628d8d82b5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 23 Apr 2025 19:42:21 +0200 Subject: [PATCH 22/39] update components manageer: add -collection arg, allow subfoldeer arg in from_pretrained --- src/diffusers/pipelines/components_manager.py | 55 ++++++++++++++----- ...line_util.py => modular_pipeline_utils.py} | 0 2 files changed, 40 insertions(+), 15 deletions(-) rename src/diffusers/pipelines/{modular_pipeline_util.py => modular_pipeline_utils.py} (100%) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 8c14321ccfac..4bd4c23c281a 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -26,6 +26,7 @@ logging, ) from ..models.modeling_utils import ModelMixin +from .modular_pipeline_utils import ComponentSpec if is_accelerate_available(): @@ -232,26 +233,36 @@ def search_best_candidate(module_sizes, min_memory_offload): class ComponentsManager: def __init__(self): self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added + self.added_time = OrderedDict() # Store when components were added + self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - def add(self, name, component): + def add(self, name, component, collection: Optional[str] = None): if name in self.components: logger.warning(f"Overriding existing component '{name}' in ComponentsManager") self.components[name] = component self.added_time[name] = time.time() - + if collection: + if collection not in self.collections: + self.collections[collection] = set() + self.collections[collection].add(name) + if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) def remove(self, name): + if name not in self.components: logger.warning(f"Component '{name}' not found in ComponentsManager") return self.components.pop(name) self.added_time.pop(name) + + for collection in self.collections: + if name in self.collections[collection]: + self.collections[collection].remove(name) if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) @@ -516,7 +527,7 @@ def __repr__(self): return output - def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): + def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): """ Load components from a pretrained model and add them to the manager. @@ -526,17 +537,12 @@ def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[st If provided, components will be named as "{prefix}_{component_name}" **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() """ - from ..pipelines.pipeline_utils import DiffusionPipeline - - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) - for name, component in pipe.components.items(): - - if component is None: - continue - - # Add prefix if specified - component_name = f"{prefix}_{name}" if prefix else name - + subfolder = kwargs.pop("subfolder", None) + # YiYi TODO: extend auto model to support non-diffusers models + if subfolder: + from ..models import AutoModel + component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) + component_name = f"{prefix}_{subfolder}" if prefix else subfolder if component_name not in self.components: self.add(component_name, component) else: @@ -545,6 +551,25 @@ def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[st f"1. remove the existing component with remove('{component_name}')\n" f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) + else: + from ..pipelines.pipeline_utils import DiffusionPipeline + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) + for name, component in pipe.components.items(): + + if component is None: + continue + + # Add prefix if specified + component_name = f"{prefix}_{name}" if prefix else name + + if component_name not in self.components: + self.add(component_name, component) + else: + logger.warning( + f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{component_name}')\n" + f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" + ) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. diff --git a/src/diffusers/pipelines/modular_pipeline_util.py b/src/diffusers/pipelines/modular_pipeline_utils.py similarity index 100% rename from src/diffusers/pipelines/modular_pipeline_util.py rename to src/diffusers/pipelines/modular_pipeline_utils.py From 2571c000547da5782e5ccd0ce448adee2a221026 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 23 Apr 2025 19:43:34 +0200 Subject: [PATCH 23/39] move componentspec, configspec, input output param to utils --- .../pipelines/modular_pipeline_utils.py | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index fb6b83c7eee0..0fec1db91e90 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -13,14 +13,61 @@ # limitations under the License. import re -from typing import Any, Dict, List, Union +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type, Union from ..utils.import_utils import is_torch_available +from ..configuration_utils import FrozenDict if is_torch_available(): import torch +@dataclass +class ComponentSpec: + """Specification for a pipeline component.""" + name: str + # YiYi NOTE: is type_hint a good fild name? it is the actual class, will be used to create the default instance + type_hint: Type + description: Optional[str] = None + config: Optional[FrozenDict[str, Any]] = None # you can specific default config to create a default component if it is a stateless class like scheduler, guider or image processor + repo: Optional[Union[str, List[str]]] = None + subfolder: Optional[str] = None + revision: Optional[str] = None + variant: Optional[str] = None + +@dataclass +class ConfigSpec: + """Specification for a pipeline configuration parameter.""" + name: str + value: Any + description: Optional[str] = None + repo: Optional[Union[str, List[str]]] = None + +@dataclass +class InputParam: + """Specification for an input parameter.""" + name: str + type_hint: Any = None + default: Any = None + required: bool = False + description: str = "" + + def __repr__(self): + return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + + +@dataclass +class OutputParam: + """Specification for an output parameter.""" + name: str + type_hint: Any = None + description: str = "" + + def __repr__(self): + return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + + def format_inputs_short(inputs): """ Format input parameters into a string representation, with required params first followed by optional ones. From 3b30e794d01ad5d73897362baee6d2eee482dd53 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 23 Apr 2025 19:45:13 +0200 Subject: [PATCH 24/39] modularpipeloine -> modularpipelineloader, setup_loader, make loader configmixin etc --- src/diffusers/pipelines/modular_pipeline.py | 231 +++++++++++++----- .../pipeline_stable_diffusion_xl_modular.py | 34 +-- 2 files changed, 185 insertions(+), 80 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 09736712a28b..a27883047ea2 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -30,7 +30,11 @@ logging, ) from .pipeline_loading_utils import _get_pipeline_class -from .modular_pipeline_util import ( +from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, format_components, format_configs, format_input_params, @@ -41,16 +45,16 @@ make_doc_string, ) - +from copy import deepcopy if is_accelerate_available(): import accelerate logger = logging.get_logger(__name__) # pylint: disable=invalid-name -MODULAR_PIPELINE_MAPPING = OrderedDict( +MODULAR_LOADER_MAPPING = OrderedDict( [ - ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), + ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), ] ) @@ -148,45 +152,6 @@ def format_value(v): return f"BlockState(\n{attributes}\n)" -@dataclass -class ComponentSpec: - """Specification for a pipeline component.""" - name: str - type_hint: Type - description: Optional[str] = None - obj: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor - default_class_name: Union[str, List[str], Tuple[str, str]] = None # Either "class_name" or ["module", "class_name"] - default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"] - -@dataclass -class ConfigSpec: - """Specification for a pipeline configuration parameter.""" - name: str - default: Any - description: Optional[str] = None - - -@dataclass -class InputParam: - name: str - type_hint: Any = None - default: Any = None - required: bool = False - description: str = "" - - def __repr__(self): - return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" - -@dataclass -class OutputParam: - name: str - type_hint: Any = None - description: str = "" - - def __repr__(self): - return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" - - class PipelineBlock: model_name = None @@ -1027,21 +992,109 @@ class ModularPipelineMixin: """ Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks """ + + def register_loader(self, global_components_manager: ComponentsManager, label: Optional[str] = None): + self._global_components_manager = global_components_manager + self._label = label + + #YiYi TODO: add validation for kwargs? + def setup_loader(self, **kwargs): + """ + Set up the components loader with repository information. + + Args: + **kwargs: Configuration for component loading. + - repo: Default repository to use for all components + - For individual components, pass a tuple of (repo, subfolder) + e.g., text_encoder=("repo_name", "text_encoder") + + Examples: + # Set repo for all components (subfolder will be component name) + setup_loader(repo="stabilityai/stable-diffusion-xl-base-1.0") + + # Set specific repo/subfolder for individual components + setup_loader( + unet=("stabilityai/stable-diffusion-xl-base-1.0", "unet"), + text_encoder=("stabilityai/stable-diffusion-xl-base-1.0", "text_encoder") + ) + + # Set default repo and override for specific components + setup_loader( + repo="stabilityai/stable-diffusion-xl-base-1.0", + unet=(""stabilityai/stable-diffusion-xl-refiner-1.0", "unet") + ) + """ + # Create deep copies to avoid modifying the original specs + components_specs = deepcopy(self.expected_components) + config_specs = deepcopy(self.expected_configs) - def __init__(self): - self.components_manager = None - self.components_manager_prefix = "" - self.components_state = None - - # YiYi TODO: not sure this is the best method name - def compile(self, components_manager: ComponentsManager, label: Optional[str] = None): - self.components_manager = components_manager - self.components_manager_prefix = "" if label is None else f"{label}_" - self.components_state = ComponentsState(self.expected_components, self.expected_configs) + expected_component_names = set([c.name for c in components_specs]) + expected_config_names = set([c.name for c in config_specs]) + + # Check if a default repo is provided + repo = kwargs.pop("repo", None) + revision = kwargs.pop("revision", None) + variant = kwargs.pop("variant", None) + + passed_component_kwargs = {k: kwargs.pop(k) for k in expected_component_names if k in kwargs} + passed_config_kwargs = {k: kwargs.pop(k) for k in expected_config_names if k in kwargs} + if len(kwargs) > 0: + logger.warning(f"Unused keyword arguments: {kwargs.keys()}. This input will be ignored.") + + for name, value in passed_component_kwargs.items(): + if not isinstance(value, (tuple, list, str)): + raise ValueError(f"Invalid value for component '{name}': {value}. Expected a string, tuple or list") + elif isinstance(value, (tuple, list)) and len(value) > 2: + raise ValueError(f"Invalid value for component '{name}': {value}. Expected a tuple or list of length 1 or 2.") + + for name, value in passed_config_kwargs.items(): + if not isinstance(value, str): + raise ValueError(f"Invalid value for config '{name}': {value}. Expected a string") + + # First apply default repo to all components if provided + if repo is not None: + for component_spec in components_specs: + # components defined with a config are classes like image_processor or guider, + # skip setting loading related attributes for them, they should be initialized with the default config + if component_spec.config is None: + component_spec.repo = repo + + # YiYi TODO: should also accept `revision` and `variant` as a dict here so user can set different values for different components + if revision is not None: + component_spec.revision = revision + if variant is not None: + component_spec.variant = variant + for config_spec in config_specs: + config_spec.repo = repo - components_to_add = self.components_manager.get(f"{self.components_manager_prefix}*") - self.components_state.update_states(self.expected_components, self.expected_configs, **components_to_add) + # apply component-specific overrides + for name, value in passed_component_kwargs.items(): + if not isinstance(value, (tuple, list)): + value = (value,) + # Find the matching component spec + for component_spec in components_specs: + if component_spec.name == name: + # Handle tuple of (repo, subfolder) + component_spec.repo = value[0] + if len(value) > 1: + component_spec.subfolder = value[1] + break + + # apply config overrides + for name, value in passed_config_kwargs.items(): + for config_spec in config_specs: + if config_spec.name == name: + config_spec.repo = value + break + + # Import components loader (it is model-specific class) + loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] + diffusers_module = importlib.import_module(self.__module__.split(".")[0]) + loader_class = getattr(diffusers_module, loader_class_name) + + # Create the loader with the updated specs + self.loader = loader_class(components_specs, config_specs) @property @@ -1105,24 +1158,69 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, raise ValueError(f"Output '{output}' is not a valid output type") -class ComponentsState(ConfigMixin): +# YiYi NOTE: not sure if this needs to be a ConfigMixin +class ModularPipelineLoader(ConfigMixin): """ - Base class for all Modular pipelines. + Base class for all Modular pipelines loaders. """ - config_name = "model_index.json" + config_name = "modular_model_index.json" + + + def register_components(self, **kwargs): + for name, module in kwargs.items(): + + repo = self.components_specs[name].repo + subfolder = self.components_specs[name].subfolder + # retrieve library + if module is None or isinstance(module, (tuple, list)) and module[0] is None: + register_dict = {name: (None, None, (None, None))} + else: + library, class_name = _fetch_class_library_tuple(module) + register_dict = {name: (library, class_name, (repo, subfolder))} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def __setattr__(self, name: str, value: Any): + if name in self.__dict__ and hasattr(self.config, name): + + repo = self.components_specs[name].repo + subfolder = self.components_specs[name].subfolder + + # We need to overwrite the config if name exists in config + if isinstance(getattr(self.config, name), (tuple, list)): + if value is not None and self.config[name][0] is not None: + library, class_name = _fetch_class_library_tuple(value) + register_dict = {name: (library, class_name, (repo, subfolder))} + else: + register_dict = {name: (None, None, (None, None))} + + self.register_to_config(**register_dict) + else: + self.register_to_config(**{name: value}) + + super().__setattr__(name, value) + def __init__(self, component_specs, config_specs): + self.components_specs = deepcopy(component_specs) + self.configs_specs = deepcopy(config_specs) + for component_spec in component_specs: - if component_spec.obj is not None: - setattr(self, component_spec.name, component_spec.obj) + if component_spec.config is not None: + component_obj = component_spec.type_hint(**component_spec.config) + self.register_components(component_spec.name, component_obj) else: - setattr(self, component_spec.name, None) + self.register_components(component_spec.name, None) default_configs = {} for config_spec in config_specs: - default_configs[config_spec.name] = config_spec.default + default_configs[config_spec.name] = config_spec.value self.register_to_config(**default_configs) @@ -1187,7 +1285,7 @@ def components(self): components[component_spec.name] = getattr(self, component_spec.name) return components - def update_states(self, expected_components, expected_configs, **kwargs): + def update(self, **kwargs): """ Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for each pipeline block, does not need to be updated by users. Logs if existing non-None components are being @@ -1197,7 +1295,7 @@ def update_states(self, expected_components, expected_configs, **kwargs): kwargs (dict): Keyword arguments to update the states. """ - for component in expected_components: + for component in self.components_specs: if component.name in kwargs: if hasattr(self, component.name) and getattr(self, component.name) is not None: current_component = getattr(self, component.name) @@ -1217,10 +1315,10 @@ def update_states(self, expected_components, expected_configs, **kwargs): f"with new value (type: {type(new_component).__name__})" ) - setattr(self.components_state, component.name, kwargs.pop(component.name)) + setattr(self, component.name, kwargs.pop(component.name)) configs_to_add = {} - for config in expected_configs: + for config in self.configs_specs: if config.name in kwargs: configs_to_add[config.name] = kwargs.pop(config.name) self.register_to_config(**configs_to_add) @@ -1228,3 +1326,4 @@ def update_states(self, expected_components, expected_configs, **kwargs): # YiYi TODO: should support to method def to(self, *args, **kwargs): pass + diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 882ec2e18552..d4563ed2ea53 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -34,7 +34,7 @@ from ..controlnet.multicontrolnet import MultiControlNetModel from ..modular_pipeline import ( AutoPipelineBlocks, - ModularPipeline, + ModularPipelineLoader, PipelineBlock, PipelineState, InputParam, @@ -58,6 +58,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...guiders import GuiderType, ClassifierFreeGuidance +from ...configuration_utils import FrozenDict import numpy as np @@ -646,7 +647,7 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()), + ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8})), ] @property @@ -741,8 +742,8 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()), - ComponentSpec("mask_processor", VaeImageProcessor, obj=VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True)), + ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8})), + ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True})), ] @@ -1728,7 +1729,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): @property def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("requires_aesthetics_score", default=False),] + return [ConfigSpec("requires_aesthetics_score", False),] @property def description(self) -> str: @@ -2063,7 +2064,7 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), + ComponentSpec("guider", GuiderType, config=FrozenDict({"guidance_scale": 7.5})), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ] @@ -2332,11 +2333,11 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), + ComponentSpec("guider", GuiderType, config=FrozenDict({"guidance_scale": 7.5})), ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), - ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False})), ] @property @@ -2763,8 +2764,8 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetUnionModel), ComponentSpec("scheduler", KarrasDiffusionSchedulers), - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), + ComponentSpec("guider", GuiderType, config=FrozenDict({"guidance_scale": 7.5})), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False})), ] @property @@ -3179,7 +3180,7 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()) + ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8})) ] @property @@ -3570,9 +3571,14 @@ def description(self): } -# YiYi TODO: rename to components etc. and not inherit from ModularPipeline -class StableDiffusionXLComponentStates( - ModularPipeline, +# YiYi Notes: model specific components: +## (1) it should inherit from ModularPipelineComponents +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? +class StableDiffusionXLModularLoader( + ModularPipelineLoader, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, From d456a97420696e45fbfde19af04c4e68fb25d4b4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 24 Apr 2025 06:44:26 +0200 Subject: [PATCH 25/39] update components manager, allow loading with spec --- src/diffusers/pipelines/components_manager.py | 55 ++++++++++++++++++- src/diffusers/pipelines/modular_pipeline.py | 3 +- .../pipelines/modular_pipeline_utils.py | 12 ++++ .../pipeline_stable_diffusion_xl_modular.py | 4 +- 4 files changed, 67 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 4bd4c23c281a..5cf471314d37 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -230,26 +230,75 @@ def search_best_candidate(module_sizes, min_memory_offload): return hooks_to_offload + +from .modular_pipeline_utils import ComponentSpec, ComponentLoadSpec class ComponentsManager: def __init__(self): self.components = OrderedDict() self.added_time = OrderedDict() # Store when components were added + self.components_specs = OrderedDict() self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - def add(self, name, component, collection: Optional[str] = None): + + def load_component(self, spec: Union[ComponentSpec, ComponentLoadSpec], **kwargs): + module_class = spec.type_hint + + + if spec.revision is not None: + kwargs["revision"] = spec.revision + if spec.variant is not None: + kwargs["variant"] = spec.variant + + component = module_class.from_pretrained(spec.repo, subfolder=spec.subfolder, **kwargs) + return component + + def add(self, name, component, collection: Optional[str] = None, load_spec: Optional[ComponentLoadSpec] = None): if name in self.components: logger.warning(f"Overriding existing component '{name}' in ComponentsManager") + self.components[name] = component self.added_time[name] = time.time() if collection: if collection not in self.collections: self.collections[collection] = set() self.collections[collection].add(name) + + if load_spec is not None: + self.components_specs[name] = load_spec if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) + self.enable_auto_cpu_offload(self._auto_offload_device) + + # YiYi TODO: combine this with add method? + def add_with_spec(self, name, spec:Union[ComponentSpec, ComponentLoadSpec], collection: Optional[str] = None, force_add: bool = False, **kwargs): + """ + Add a component to the manager. + + Args: + name: Name of the component in the ComponentsManager + component: The ComponentSpec to load + collection: Optional collection to add the component to + force_add: If True, always add the component even if the ComponentSpec already exists + **kwargs: Additional arguments to pass to the component loader + """ + + if isinstance(spec, ComponentSpec): + if spec.config is not None: + component = spec.type_hint(**spec.config) + self.add(name, component, collection=collection, load_spec=ComponentLoadSpec.from_component_spec(spec)) + return + + spec = ComponentLoadSpec.from_component_spec(spec) + + for k, v in self.components_specs.items(): + if v == spec and not force_add: + logger.warning(f"will not add {name} to ComponentsManager, as {k} already exists with same spec.Please use force_add=True to add it.") + return + + component = self.load_component(spec, **kwargs) + self.add(name, component, collection=collection, load_spec=spec) def remove(self, name): @@ -538,7 +587,7 @@ def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() """ subfolder = kwargs.pop("subfolder", None) - # YiYi TODO: extend auto model to support non-diffusers models + # YiYi TODO: extend AutoModel to support non-diffusers models if subfolder: from ..models import AutoModel component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index a27883047ea2..3059caf212c8 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -1158,8 +1158,7 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, raise ValueError(f"Output '{output}' is not a valid output type") -# YiYi NOTE: not sure if this needs to be a ConfigMixin -class ModularPipelineLoader(ConfigMixin): +class ModularLoader(ConfigMixin): """ Base class for all Modular pipelines loaders. diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index 0fec1db91e90..282b94bb083c 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -36,6 +36,18 @@ class ComponentSpec: revision: Optional[str] = None variant: Optional[str] = None +@dataclass +class ComponentLoadSpec: + type_hint: type + repo: Optional[str] = None + subfolder: Optional[str] = None + revision: Optional[str] = None + variant: Optional[str] = None + + @classmethod + def from_component_spec(cls, component_spec: ComponentSpec): + return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder, revision=component_spec.revision, variant=component_spec.variant) + @dataclass class ConfigSpec: """Specification for a pipeline configuration parameter.""" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index d4563ed2ea53..1ff0befb1597 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -34,7 +34,7 @@ from ..controlnet.multicontrolnet import MultiControlNetModel from ..modular_pipeline import ( AutoPipelineBlocks, - ModularPipelineLoader, + ModularLoader, PipelineBlock, PipelineState, InputParam, @@ -3578,7 +3578,7 @@ def description(self): ## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) ## (5) how to use together with Components_manager? class StableDiffusionXLModularLoader( - ModularPipelineLoader, + ModularLoader, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, From a1eb9ee951d3973abc96dd1a3f40c683f1866353 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 24 Apr 2025 12:31:29 +0200 Subject: [PATCH 26/39] make component spec loadable: add load/create method --- src/diffusers/pipelines/components_manager.py | 41 +++++++---------- .../pipelines/modular_pipeline_utils.py | 45 +++++++++++++++---- 2 files changed, 51 insertions(+), 35 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 5cf471314d37..f9a039ddaa12 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -241,19 +241,6 @@ def __init__(self): self.model_hooks = None self._auto_offload_enabled = False - - def load_component(self, spec: Union[ComponentSpec, ComponentLoadSpec], **kwargs): - module_class = spec.type_hint - - - if spec.revision is not None: - kwargs["revision"] = spec.revision - if spec.variant is not None: - kwargs["variant"] = spec.variant - - component = module_class.from_pretrained(spec.repo, subfolder=spec.subfolder, **kwargs) - return component - def add(self, name, component, collection: Optional[str] = None, load_spec: Optional[ComponentLoadSpec] = None): if name in self.components: logger.warning(f"Overriding existing component '{name}' in ComponentsManager") @@ -284,21 +271,23 @@ def add_with_spec(self, name, spec:Union[ComponentSpec, ComponentLoadSpec], coll **kwargs: Additional arguments to pass to the component loader """ - if isinstance(spec, ComponentSpec): - if spec.config is not None: - component = spec.type_hint(**spec.config) - self.add(name, component, collection=collection, load_spec=ComponentLoadSpec.from_component_spec(spec)) - return - - spec = ComponentLoadSpec.from_component_spec(spec) - + if isinstance(spec, ComponentSpec) and spec.repo is None: + component = spec.create(**kwargs) + self.add(name, component, collection=collection) + elif isinstance(spec, ComponentSpec): + load_spec = spec.to_load_spec() + elif isinstance(spec, ComponentLoadSpec): + load_spec = spec + else: + raise ValueError(f"Invalid spec type: {type(spec)}") + for k, v in self.components_specs.items(): - if v == spec and not force_add: - logger.warning(f"will not add {name} to ComponentsManager, as {k} already exists with same spec.Please use force_add=True to add it.") + if v == load_spec and not force_add: + logger.warning(f"{name} is not added to ComponentsManager, because `{k}` already exists with same spec. Please use `force_add=True` to add it.") return - - component = self.load_component(spec, **kwargs) - self.add(name, component, collection=collection, load_spec=spec) + + component = load_spec.load(**kwargs) + self.add(name, component, collection=collection, load_spec=load_spec) def remove(self, name): diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index 282b94bb083c..bb8cc1283e8b 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -13,7 +13,7 @@ # limitations under the License. import re -from dataclasses import dataclass +from dataclasses import dataclass, asdict from typing import Any, Dict, List, Optional, Tuple, Type, Union from ..utils.import_utils import is_torch_available @@ -27,26 +27,53 @@ class ComponentSpec: """Specification for a pipeline component.""" name: str - # YiYi NOTE: is type_hint a good fild name? it is the actual class, will be used to create the default instance - type_hint: Type + type_hint: Type # YiYi Notes: change to component_type? description: Optional[str] = None config: Optional[FrozenDict[str, Any]] = None # you can specific default config to create a default component if it is a stateless class like scheduler, guider or image processor repo: Optional[Union[str, List[str]]] = None subfolder: Optional[str] = None - revision: Optional[str] = None - variant: Optional[str] = None + + def create(self, **kwargs) -> Any: + """ + Create the component based on the config and additional kwargs. + + Args: + **kwargs: Additional arguments to pass to the component's __init__ method + + Returns: + The created component + """ + if self.config is not None: + init_kwargs = self.config + else: + init_kwargs = {} + return self.type_hint(**init_kwargs, **kwargs) + + def load(self, **kwargs) -> Any: + return self.to_load_spec().load(**kwargs) + + def to_load_spec(self) -> "ComponentLoadSpec": + """Convert to a ComponentLoadSpec for storage in ComponentsManager.""" + return ComponentLoadSpec.from_component_spec(self) @dataclass class ComponentLoadSpec: type_hint: type repo: Optional[str] = None subfolder: Optional[str] = None - revision: Optional[str] = None - variant: Optional[str] = None + def load(self, **kwargs) -> Any: + """Load the component from the repository.""" + repo = kwargs.pop("repo", self.repo) + subfolder = kwargs.pop("subfolder", self.subfolder) + + return self.type_hint.from_pretrained(repo, subfolder=subfolder, **kwargs) + + @classmethod def from_component_spec(cls, component_spec: ComponentSpec): - return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder, revision=component_spec.revision, variant=component_spec.variant) + return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder) + @dataclass class ConfigSpec: @@ -54,7 +81,7 @@ class ConfigSpec: name: str value: Any description: Optional[str] = None - repo: Optional[Union[str, List[str]]] = None + repo: Optional[Union[str, List[str]]] = None #YiYi Notes: not sure if this field is needed @dataclass class InputParam: From e2dcf9a5e4ffd475a7bb9d69de68feb110dd8ea5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 24 Apr 2025 12:32:40 +0200 Subject: [PATCH 27/39] update ModularLarder, add save/from_pretrained, proper register_components, update --- src/diffusers/pipelines/modular_pipeline.py | 230 +++++++++++++------- 1 file changed, 149 insertions(+), 81 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 3059caf212c8..e07ee36103e8 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -22,12 +22,16 @@ import torch from tqdm.auto import tqdm import re +import os + +from huggingface_hub.utils import validate_hf_hub_args from ..configuration_utils import ConfigMixin from ..utils import ( is_accelerate_available, is_accelerate_version, logging, + PushToHubMixin, ) from .pipeline_loading_utils import _get_pipeline_class from .modular_pipeline_utils import ( @@ -993,9 +997,9 @@ class ModularPipelineMixin: Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks """ - def register_loader(self, global_components_manager: ComponentsManager, label: Optional[str] = None): - self._global_components_manager = global_components_manager - self._label = label + # def register_loader(self, global_components_manager: ComponentsManager, label: Optional[str] = None): + # self._global_components_manager = global_components_manager + # self._label = label #YiYi TODO: add validation for kwargs? def setup_loader(self, **kwargs): @@ -1026,10 +1030,10 @@ def setup_loader(self, **kwargs): """ # Create deep copies to avoid modifying the original specs - components_specs = deepcopy(self.expected_components) + component_specs = deepcopy(self.expected_components) config_specs = deepcopy(self.expected_configs) - expected_component_names = set([c.name for c in components_specs]) + expected_component_names = set([c.name for c in component_specs]) expected_config_names = set([c.name for c in config_specs]) # Check if a default repo is provided @@ -1054,7 +1058,7 @@ def setup_loader(self, **kwargs): # First apply default repo to all components if provided if repo is not None: - for component_spec in components_specs: + for component_spec in component_specs: # components defined with a config are classes like image_processor or guider, # skip setting loading related attributes for them, they should be initialized with the default config if component_spec.config is None: @@ -1073,7 +1077,7 @@ def setup_loader(self, **kwargs): if not isinstance(value, (tuple, list)): value = (value,) # Find the matching component spec - for component_spec in components_specs: + for component_spec in component_specs: if component_spec.name == name: # Handle tuple of (repo, subfolder) component_spec.repo = value[0] @@ -1094,7 +1098,7 @@ def setup_loader(self, **kwargs): loader_class = getattr(diffusers_module, loader_class_name) # Create the loader with the updated specs - self.loader = loader_class(components_specs, config_specs) + self.loader = loader_class(component_specs, config_specs) @property @@ -1158,7 +1162,45 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, raise ValueError(f"Output '{output}' is not a valid output type") -class ModularLoader(ConfigMixin): +def _find_spec_by_name(specs: List[Union[ComponentSpec, ConfigSpec]], name: str) -> Union[ComponentSpec, ConfigSpec]: + for spec in specs: + if hasattr(spec, "name") and spec.name == name: + return spec + logger.warning(f"'{name}' not found in the specs") + return None + +# YiYi TODO: refactor the _fetch_class_library_tuple in pipeline_loading_utils.py to acceept class (current object) +from .pipeline_loading_utils import LOADABLE_CLASSES +import importlib +def _fetch_class_library_tuple(module_class): + # import it here to avoid circular import + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") + + library = module_class.__module__.split(".")[0] + + # check if the module is a pipeline module + module_path_items = module_class.__module__.split(".") + pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None + + path = module_class.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if is_pipeline_module: + library = pipeline_dir + elif library not in LOADABLE_CLASSES: + library = module_class.__module__ + + # retrieve class_name + class_name = module_class.__name__ + + return (library, class_name) + + +class ModularLoader(ConfigMixin, PushToHubMixin): """ Base class for all Modular pipelines loaders. @@ -1167,16 +1209,24 @@ class ModularLoader(ConfigMixin): def register_components(self, **kwargs): + """ + Register components with their corresponding specs. + + Args: + **kwargs: Keyword arguments where keys are component names and values are component objects. + + """ for name, module in kwargs.items(): + component_spec = _find_spec_by_name(self.component_specs, name) + library, class_name = _fetch_class_library_tuple(component_spec.type_hint) + load_spec_dict = OrderedDict({ + "repo": component_spec.repo, + "subfolder": component_spec.subfolder, + }) - repo = self.components_specs[name].repo - subfolder = self.components_specs[name].subfolder # retrieve library - if module is None or isinstance(module, (tuple, list)) and module[0] is None: - register_dict = {name: (None, None, (None, None))} - else: - library, class_name = _fetch_class_library_tuple(module) - register_dict = {name: (library, class_name, (repo, subfolder))} + + register_dict = {name: (library, class_name, load_spec_dict)} # save model index config self.register_to_config(**register_dict) @@ -1184,41 +1234,25 @@ def register_components(self, **kwargs): # set models setattr(self, name, module) - def __setattr__(self, name: str, value: Any): - if name in self.__dict__ and hasattr(self.config, name): - - repo = self.components_specs[name].repo - subfolder = self.components_specs[name].subfolder - - # We need to overwrite the config if name exists in config - if isinstance(getattr(self.config, name), (tuple, list)): - if value is not None and self.config[name][0] is not None: - library, class_name = _fetch_class_library_tuple(value) - register_dict = {name: (library, class_name, (repo, subfolder))} - else: - register_dict = {name: (None, None, (None, None))} - - self.register_to_config(**register_dict) - else: - self.register_to_config(**{name: value}) - - super().__setattr__(name, value) - - - def __init__(self, component_specs, config_specs): - self.components_specs = deepcopy(component_specs) - self.configs_specs = deepcopy(config_specs) + def __init__(self, component_specs: List[ComponentSpec], config_specs: Optional[List[ConfigSpec]]=None): + + if config_specs is not None: + self.config_specs = deepcopy(config_specs) + else: + self.config_specs = [] + + if component_specs is None: + self.component_specs = [] + else: + self.component_specs = deepcopy(component_specs) - for component_spec in component_specs: - if component_spec.config is not None: - component_obj = component_spec.type_hint(**component_spec.config) - self.register_components(component_spec.name, component_obj) - else: - self.register_components(component_spec.name, None) + for component_spec in self.component_specs: + register_dict = {component_spec.name: None} + self.register_components(**register_dict) default_configs = {} - for config_spec in config_specs: + for config_spec in self.config_specs: default_configs[config_spec.name] = config_spec.value self.register_to_config(**default_configs) @@ -1279,50 +1313,84 @@ def dtype(self) -> torch.dtype: @property def components(self): components = {} - for component_spec in self.expected_components: + for component_spec in self.component_specs: if hasattr(self, component_spec.name): components[component_spec.name] = getattr(self, component_spec.name) return components - def update(self, **kwargs): + def update(self, repo=None, **kwargs): """ - Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for - each pipeline block, does not need to be updated by users. Logs if existing non-None components are being - overwritten. - + Update components and configs specs after instance creation. + Args: - kwargs (dict): Keyword arguments to update the states. + repo (str, optional): Default repository to use for all components + **kwargs: Updates, which can be: + - For components: + - A string: Used as the repository name + - A tuple: (repo, subfolder) or (repo,) + - A ComponentSpec: Used to replace the existing spec + + If the components already exist in the loader, it will load the component + with updated info and replace the existing one; otherwise, it will only + update the spec. + + - For configs: + - Any value: Used to update the config value as well as the value field in the config spec """ + # Update global defaults if provided + if repo is not None: + for component_spec in self.component_specs: + component_spec.repo = repo + + # Process all updates + register_components_dict = {} + for component_spec in self.component_specs: + if component_spec.name in kwargs: + # update the component spec + component_kwargs = kwargs.pop(component_spec.name) + if isinstance(component_kwargs, ComponentSpec): + component_spec = component_kwargs + elif isinstance(component_kwargs, str): + component_spec.repo = component_kwargs + elif isinstance(component_kwargs, tuple): + component_spec.repo = component_kwargs[0] + if len(component_kwargs) > 1: + component_spec.subfolder = component_kwargs[1] + + if self.components[component_spec.name]: + new_component = component_spec.load() + else: + new_component = None + register_components_dict[component_spec.name] = new_component + + self.register_components(**register_components_dict) + + register_configs_dict = {} + for config_spec in self.config_specs: + if config_spec.name in kwargs: + config_value = kwargs.pop(config_spec.name) + if isinstance(config_value, ConfigSpec): + config_spec = config_value + else: + config_spec.value = config_value + register_configs_dict[config_spec.name] = config_spec.value + self.register_to_config(**register_configs_dict) - for component in self.components_specs: - if component.name in kwargs: - if hasattr(self, component.name) and getattr(self, component.name) is not None: - current_component = getattr(self, component.name) - new_component = kwargs[component.name] - - if not isinstance(new_component, current_component.__class__): - logger.info( - f"Overwriting existing component '{component.name}' " - f"(type: {current_component.__class__.__name__}) " - f"with type: {new_component.__class__.__name__})" - ) - elif isinstance(current_component, torch.nn.Module): - if id(current_component) != id(new_component): - logger.info( - f"Overwriting existing component '{component.name}' " - f"(type: {type(current_component).__name__}) " - f"with new value (type: {type(new_component).__name__})" - ) - - setattr(self, component.name, kwargs.pop(component.name)) - - configs_to_add = {} - for config in self.configs_specs: - if config.name in kwargs: - configs_to_add[config.name] = kwargs.pop(config.name) - self.register_to_config(**configs_to_add) + + if len(kwargs) > 0: + logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.") # YiYi TODO: should support to method def to(self, *args, **kwargs): pass + # YiYi TODO: should support save some components too! currently only modular_model_index.json is saved + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs): + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) + return config_dict + From 267a1af6ab84f0c79d5b777b557f5fc3977d5b3c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 25 Apr 2025 16:08:43 +0200 Subject: [PATCH 28/39] up --- src/diffusers/pipelines/modular_pipeline.py | 312 +++++++++++++----- .../pipelines/modular_pipeline_utils.py | 4 +- 2 files changed, 230 insertions(+), 86 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index e07ee36103e8..5cd5221a7602 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -1162,13 +1162,6 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, raise ValueError(f"Output '{output}' is not a valid output type") -def _find_spec_by_name(specs: List[Union[ComponentSpec, ConfigSpec]], name: str) -> Union[ComponentSpec, ConfigSpec]: - for spec in specs: - if hasattr(spec, "name") and spec.name == name: - return spec - logger.warning(f"'{name}' not found in the specs") - return None - # YiYi TODO: refactor the _fetch_class_library_tuple in pipeline_loading_utils.py to acceept class (current object) from .pipeline_loading_utils import LOADABLE_CLASSES import importlib @@ -1200,6 +1193,19 @@ def _fetch_class_library_tuple(module_class): return (library, class_name) +def simple_import_class_obj(library_name, class_name): + from diffusers import pipelines + is_pipeline_module = hasattr(pipelines, library_name) + + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + else: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + + return class_obj + class ModularLoader(ConfigMixin, PushToHubMixin): """ Base class for all Modular pipelines loaders. @@ -1210,50 +1216,89 @@ class ModularLoader(ConfigMixin, PushToHubMixin): def register_components(self, **kwargs): """ - Register components with their corresponding specs. + Register components with their corresponding specs. + This method is called when component changed or its spec changed (in self.component_specs). Args: **kwargs: Keyword arguments where keys are component names and values are component objects. """ for name, module in kwargs.items(): - component_spec = _find_spec_by_name(self.component_specs, name) + + current_module = getattr(self, name, None) + + # update config based on the updated component spec + component_spec = self.component_specs.get(name) + if component_spec is None: + logger.warning(f"register_components: skipping unknown component '{name}'") + continue + library, class_name = _fetch_class_library_tuple(component_spec.type_hint) load_spec_dict = OrderedDict({ "repo": component_spec.repo, "subfolder": component_spec.subfolder, }) - # retrieve library - register_dict = {name: (library, class_name, load_spec_dict)} # save model index config self.register_to_config(**register_dict) - # set models + # set the component as attribute + # skip if the component is already registered with the same object + if current_module is module: + logger.info(f"register_components: {name} is already registered with same object, skipping") + continue + + # it module is not an instance of the expected type, still register it but with a warning + if module is not None and not isinstance(module, component_spec.type_hint): + logger.warning(f"register_components: adding {name} with type: {module.__class__.__name__}, expected: {component_spec.type_hint.__name__}") + + # warn if unregister + if current_module is not None and module is None: + logger.info( + f"register_components: setting '{name}' to None " + f"(was {current_module.__class__.__name__})" + ) + # warn if class mismatch + elif current_module is not None \ + and module is not None \ + and not isinstance(module, current_module.__class__): + logger.warning( + f"register_components: overwriting component '{name}' " + f"(type {current_module.__class__.__name__}) " + f"with DIFFERENT type {module.__class__.__name__}" + ) + # same type, new instance → debug + elif current_module is not None \ + and module is not None \ + and isinstance(module, current_module.__class__) \ + and current_module != module: + logger.debug( + f"register_components: replacing existing '{name}' " + f"(same type {type(current_module).__name__}, new instance)" + ) + + # finally set models setattr(self, name, module) def __init__(self, component_specs: List[ComponentSpec], config_specs: Optional[List[ConfigSpec]]=None): - - if config_specs is not None: - self.config_specs = deepcopy(config_specs) - else: - self.config_specs = [] - - if component_specs is None: - self.component_specs = [] - else: - self.component_specs = deepcopy(component_specs) + self.component_specs = { + spec.name: deepcopy(spec) for spec in (component_specs or []) + } + self.config_specs = { + spec.name: deepcopy(spec) for spec in (config_specs or []) + } - for component_spec in self.component_specs: - register_dict = {component_spec.name: None} - self.register_components(**register_dict) + register_components_dict = {} + for component_spec in self.component_specs.values(): + register_components_dict[component_spec.name] = None + self.register_components(**register_components_dict) default_configs = {} - for config_spec in self.config_specs: - default_configs[config_spec.name] = config_spec.value + for config_spec in self.config_specs.values(): + default_configs[config_spec.name] = config_spec.default self.register_to_config(**default_configs) @@ -1311,86 +1356,187 @@ def dtype(self) -> torch.dtype: @property - def components(self): - components = {} - for component_spec in self.component_specs: - if hasattr(self, component_spec.name): - components[component_spec.name] = getattr(self, component_spec.name) - return components + def components(self) -> Dict[str, Any]: + # return only components we've actually set as attributes on self + return { + name: getattr(self, name) + for name in self.component_specs.keys() + if hasattr(self, name) + } def update(self, repo=None, **kwargs): """ - Update components and configs specs after instance creation. + Update components and configs after instance creation. Args: repo (str, optional): Default repository to use for all components - **kwargs: Updates, which can be: + **kwargs: + Updates, which can be: - For components: - A string: Used as the repository name - A tuple: (repo, subfolder) or (repo,) - - A ComponentSpec: Used to replace the existing spec - - If the components already exist in the loader, it will load the component - with updated info and replace the existing one; otherwise, it will only - update the spec. - + - A ComponentSpec: Replace the existing spec + + If the component is already loaded, it will be reloaded with updated info; + otherwise only the spec is updated. + - For configs: - - Any value: Used to update the config value as well as the value field in the config spec - """ - # Update global defaults if provided - if repo is not None: - for component_spec in self.component_specs: - component_spec.repo = repo + - Any value: Update the config value + + - Additional loader options: + Passed through to the underlying component loading methods + (e.g., from_pretrained), such as torch_dtype, revision, variant, etc. + """ + + # extract component_updates from `kwargs``: + # e.g. loader.update(unet=..., vae=...)` -> {"unet": ..., "vae": ...} + component_updates = {k: kwargs.pop(k) for k in self.component_specs.keys() if k in kwargs} + # extract config_updates from `kwargs``: + # e.g. loader.update(requires_aesthetics_score=False) -> {"requires_aesthetics_score": False} + config_updates = {k: kwargs.pop(k) for k in self.config_specs.keys() if k in kwargs} + + # create a dict to contain all the component specs to be updated, + new_components_specs = {} - # Process all updates - register_components_dict = {} - for component_spec in self.component_specs: - if component_spec.name in kwargs: - # update the component spec - component_kwargs = kwargs.pop(component_spec.name) - if isinstance(component_kwargs, ComponentSpec): - component_spec = component_kwargs - elif isinstance(component_kwargs, str): - component_spec.repo = component_kwargs - elif isinstance(component_kwargs, tuple): - component_spec.repo = component_kwargs[0] - if len(component_kwargs) > 1: - component_spec.subfolder = component_kwargs[1] + # update global default repo on each component spec + # e.g loader.update(repo="new_repo") -> {"unet": ComponentSpec(repo="new_repo", ...), "vae": ComponentSpec(repo="new_repo", ...)} + if repo is not None: + for spec in self.component_specs.values(): + new_spec = deepcopy(spec) + new_spec.repo = repo + new_components_specs[spec.name] = new_spec + + # update component specs with component updates extracted from the `kwargs` + # YiYi Notes: should we automatically reload? + for name, new_value in component_updates.items(): + # make a copy of the spec to avoid partial mutation + new_spec = deepcopy(self.component_specs[name]) + if isinstance(new_value, ComponentSpec): + # e.g. loader.update(unet = ComponentSpec(type_hint=UNet2DConditionModel, ...)) + new_spec = new_value + elif isinstance(new_value, str): + # e.g. loader.update(unet="repo/unet") + new_spec.repo = new_value + elif isinstance(new_value, (tuple, list)): + # e.g. loader.update(unet = ("repo/unet", "subfolder")) + new_spec.repo = new_value[0] + if len(new_value) > 1: + new_spec.subfolder = new_value[1] + + # potentially override the spec if global repo is provided + new_components_specs[name] = new_spec - if self.components[component_spec.name]: - new_component = component_spec.load() - else: - new_component = None - register_components_dict[component_spec.name] = new_component + # attempt to update the components if it's already loaded + components_to_register = {} + for name, new_component_spec in new_components_specs.items(): + if getattr(self, name, None) is not None: + try: + # perform atomic update only if successful load the new component + # load, update components_spec and register_components + new_component = new_component_spec.load(**kwargs) + self.component_specs[name] = new_component_spec + components_to_register[name] = new_component + except Exception as e: + logger.warning(f"Failed to update component '{name}': {e}") + else: + # only update the spec if the component is not loaded (e.g. self.unet = None) + self.component_specs[name] = new_component_spec + components_to_register[name] = None - self.register_components(**register_components_dict) - - register_configs_dict = {} - for config_spec in self.config_specs: - if config_spec.name in kwargs: - config_value = kwargs.pop(config_spec.name) - if isinstance(config_value, ConfigSpec): - config_spec = config_value - else: - config_spec.value = config_value - register_configs_dict[config_spec.name] = config_spec.value - self.register_to_config(**register_configs_dict) - + self.register_components(**components_to_register) + + config_to_register = {} + for name, new_value in config_updates.items(): + if isinstance(new_value, ConfigSpec): + # e.g. requires_aesthetics_score = ConfigSpec(name="requires_aesthetics_score", default=False) + self.config_specs[name] = new_value + config_to_register[name] = new_value.default + else: + # e.g. requires_aesthetics_score = False + self.config_specs[name].default = new_value + config_to_register[name] = new_value + self.register_to_config(**config_to_register) - if len(kwargs) > 0: - logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.") + def load(self, **kwargs): + """ + Load components and optionally set config values. + + This method has three modes: + 1. `self.load()` - load all components from their specs + 2. `self.load(unet=unet, text_encoder=text_encoder)` - use provided components directly, + load remaining components from specs + 3. `self.load(...,requires_aesthetics_score=False)` - additinally set config values + + Args: + **kwargs: Can include: + - Component objects to set directly (e.g., unet=my_unet) + - config values to set (e.g., requires_aesthetics_score=False) + - additional kwargs to be passed to `from_pretrained()`, e.g. torch_dtype=torch.bfloat16 + + Returns: + self: The loader instance with loaded components + """ + config_updates = {k: kwargs.pop(k) for k in self.config_specs.keys() if k in kwargs} + passed_component_obj = {k: kwargs.pop(k) for k in self.component_specs.keys() if k in kwargs} + + # 1. Set any config values provided (without updating defaults in specs) + if config_updates: + self.register_to_config(**config_updates) + + # 2. Process components + components_to_register = {} + + # First register the components provided directly + for name, component in passed_component_obj.items(): + components_to_register[name] = component + + # Then load the remaining components from specs + remaining_components = set(self.component_specs.keys()) - set(passed_component_obj.keys()) + for name in remaining_components: + spec = self.component_specs[name] + try: + if spec.repo is not None: + components_to_register[name] = spec.load(**kwargs) + elif spec.config is not None: + components_to_register[name] = spec.create() + except Exception as e: + logger.warning(f"Failed to create component '{name}': {e}") + + # Register all components at once + self.register_components(**components_to_register) # YiYi TODO: should support to method def to(self, *args, **kwargs): pass - # YiYi TODO: should support save some components too! currently only modular_model_index.json is saved + # YiYi TODO: + # 1. should support save some components too! currently only modular_model_index.json is saved + # 2. maybe order the json file to make it more readable: configs first, then components def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + + component_names = list(self.component_specs.keys()) + config_names = list(self.config_specs.keys()) + self.register_to_config(_components_names=component_names, _configs_names=config_names) self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs): + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) - return config_dict + expected_component = set(config_dict.pop("_components_names")) + expected_config = set(config_dict.pop("_configs_names")) + + component_specs = [] + config_specs = [] + for name, value in config_dict.items(): + if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: + library_name, class_name, load_spec_dict = value + type_hint = simple_import_class_obj(library_name, class_name) + component_specs.append(ComponentSpec(name=name, type_hint=type_hint, **load_spec_dict)) + + elif name in expected_config: + config_specs.append(ConfigSpec(name=name, default=value)) + return cls(component_specs, config_specs=config_specs) + diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index bb8cc1283e8b..cc01138e8d3a 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -79,10 +79,8 @@ def from_component_spec(cls, component_spec: ComponentSpec): class ConfigSpec: """Specification for a pipeline configuration parameter.""" name: str - value: Any + default: Any description: Optional[str] = None - repo: Optional[Union[str, List[str]]] = None #YiYi Notes: not sure if this field is needed - @dataclass class InputParam: """Specification for an input parameter.""" From 3e4a772ead837ec3361b9384ace4e54768edc731 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 25 Apr 2025 19:43:38 +0200 Subject: [PATCH 29/39] fix --- src/diffusers/pipelines/modular_pipeline.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 5cd5221a7602..ccaed29daa51 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -1225,8 +1225,8 @@ def register_components(self, **kwargs): """ for name, module in kwargs.items(): - current_module = getattr(self, name, None) - + is_initialized = hasattr(self, name) + # update config based on the updated component spec component_spec = self.component_specs.get(name) if component_spec is None: @@ -1245,6 +1245,12 @@ def register_components(self, **kwargs): self.register_to_config(**register_dict) # set the component as attribute + # if it is not set yet, just set it and skip the warnings below + if not is_initialized: + setattr(self, name, module) + continue + + current_module = getattr(self, name, None) # skip if the component is already registered with the same object if current_module is module: logger.info(f"register_components: {name} is already registered with same object, skipping") From e8b5cde376a1e1375003e5c1850701553ea59276 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Apr 2025 09:22:11 +0200 Subject: [PATCH 30/39] up! --- src/diffusers/pipelines/components_manager.py | 40 +- src/diffusers/pipelines/modular_pipeline.py | 429 ++++++++++-------- .../pipelines/modular_pipeline_utils.py | 209 +++++++-- .../pipelines/pipeline_loading_utils.py | 5 +- 4 files changed, 418 insertions(+), 265 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index f9a039ddaa12..c5934a8be768 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -231,17 +231,17 @@ def search_best_candidate(module_sizes, min_memory_offload): -from .modular_pipeline_utils import ComponentSpec, ComponentLoadSpec +from .modular_pipeline_utils import ComponentSpec class ComponentsManager: def __init__(self): self.components = OrderedDict() self.added_time = OrderedDict() # Store when components were added - self.components_specs = OrderedDict() + self.load_ids = OrderedDict() # Store load_id of components (for model loaded with ComponentSpec) self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - def add(self, name, component, collection: Optional[str] = None, load_spec: Optional[ComponentLoadSpec] = None): + def add(self, name, component, collection: Optional[str] = None): if name in self.components: logger.warning(f"Overriding existing component '{name}' in ComponentsManager") @@ -252,42 +252,12 @@ def add(self, name, component, collection: Optional[str] = None, load_spec: Opti self.collections[collection] = set() self.collections[collection].add(name) - if load_spec is not None: - self.components_specs[name] = load_spec + if hasattr(component, "_diffusers_load_id"): + self.load_ids[name] = component._diffusers_load_id if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - # YiYi TODO: combine this with add method? - def add_with_spec(self, name, spec:Union[ComponentSpec, ComponentLoadSpec], collection: Optional[str] = None, force_add: bool = False, **kwargs): - """ - Add a component to the manager. - - Args: - name: Name of the component in the ComponentsManager - component: The ComponentSpec to load - collection: Optional collection to add the component to - force_add: If True, always add the component even if the ComponentSpec already exists - **kwargs: Additional arguments to pass to the component loader - """ - - if isinstance(spec, ComponentSpec) and spec.repo is None: - component = spec.create(**kwargs) - self.add(name, component, collection=collection) - elif isinstance(spec, ComponentSpec): - load_spec = spec.to_load_spec() - elif isinstance(spec, ComponentLoadSpec): - load_spec = spec - else: - raise ValueError(f"Invalid spec type: {type(spec)}") - - for k, v in self.components_specs.items(): - if v == load_spec and not force_add: - logger.warning(f"{name} is not added to ComponentsManager, because `{k}` already exists with same spec. Please use `force_add=True` to add it.") - return - - component = load_spec.load(**kwargs) - self.add(name, component, collection=collection, load_spec=load_spec) def remove(self, name): diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index ccaed29daa51..226c26a0b75d 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -26,7 +26,7 @@ from huggingface_hub.utils import validate_hf_hub_args -from ..configuration_utils import ConfigMixin +from ..configuration_utils import ConfigMixin, FrozenDict from ..utils import ( is_accelerate_available, is_accelerate_version, @@ -1162,37 +1162,10 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, raise ValueError(f"Output '{output}' is not a valid output type") -# YiYi TODO: refactor the _fetch_class_library_tuple in pipeline_loading_utils.py to acceept class (current object) -from .pipeline_loading_utils import LOADABLE_CLASSES -import importlib -def _fetch_class_library_tuple(module_class): - # import it here to avoid circular import - diffusers_module = importlib.import_module(__name__.split(".")[0]) - pipelines = getattr(diffusers_module, "pipelines") - - library = module_class.__module__.split(".")[0] - - # check if the module is a pipeline module - module_path_items = module_class.__module__.split(".") - pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None - - path = module_class.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. - if is_pipeline_module: - library = pipeline_dir - elif library not in LOADABLE_CLASSES: - library = module_class.__module__ - - # retrieve class_name - class_name = module_class.__name__ - - return (library, class_name) +from .pipeline_loading_utils import _fetch_class_library_tuple +import importlib def simple_import_class_obj(library_name, class_name): from diffusers import pipelines is_pipeline_module = hasattr(pipelines, library_name) @@ -1206,6 +1179,11 @@ def simple_import_class_obj(library_name, class_name): return class_obj + +# YiYi TODO: +# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) +# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader +# 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() class ModularLoader(ConfigMixin, PushToHubMixin): """ Base class for all Modular pipelines loaders. @@ -1217,7 +1195,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): def register_components(self, **kwargs): """ Register components with their corresponding specs. - This method is called when component changed or its spec changed (in self.component_specs). + This method is called when component changed or __init__ is called. Args: **kwargs: Keyword arguments where keys are component names and values are component objects. @@ -1225,28 +1203,40 @@ def register_components(self, **kwargs): """ for name, module in kwargs.items(): - is_initialized = hasattr(self, name) - - # update config based on the updated component spec - component_spec = self.component_specs.get(name) + # current component spec + component_spec = self._component_specs.get(name) if component_spec is None: logger.warning(f"register_components: skipping unknown component '{name}'") continue + + is_registered = hasattr(self, name) - library, class_name = _fetch_class_library_tuple(component_spec.type_hint) - load_spec_dict = OrderedDict({ - "repo": component_spec.repo, - "subfolder": component_spec.subfolder, - }) + if module is not None and not hasattr(module, "_diffusers_load_id"): + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - register_dict = {name: (library, class_name, load_spec_dict)} + # actual library and class name of the module - # save model index config - self.register_to_config(**register_dict) + if module is not None: + library, class_name = _fetch_class_library_tuple(module) + new_component_spec = ComponentSpec.from_component(name, module) + component_spec_dict = self._component_spec_to_dict(new_component_spec) + + else: + library, class_name = None, None + # if module is None, we do not update the spec, + # but we still need to update the config to make sure it's synced with the component spec + # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) + new_component_spec = component_spec + component_spec_dict = self._component_spec_to_dict(component_spec) + + + register_dict = {name: (library, class_name, component_spec_dict)} # set the component as attribute - # if it is not set yet, just set it and skip the warnings below - if not is_initialized: + # if it is not set yet, just set it and skip the process to check and warn below + if not is_registered: + self.register_to_config(**register_dict) + self._component_specs[name] = new_component_spec setattr(self, name, module) continue @@ -1257,8 +1247,8 @@ def register_components(self, **kwargs): continue # it module is not an instance of the expected type, still register it but with a warning - if module is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"register_components: adding {name} with type: {module.__class__.__name__}, expected: {component_spec.type_hint.__name__}") + if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): + logger.warning(f"register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") # warn if unregister if current_module is not None and module is None: @@ -1266,15 +1256,6 @@ def register_components(self, **kwargs): f"register_components: setting '{name}' to None " f"(was {current_module.__class__.__name__})" ) - # warn if class mismatch - elif current_module is not None \ - and module is not None \ - and not isinstance(module, current_module.__class__): - logger.warning( - f"register_components: overwriting component '{name}' " - f"(type {current_module.__class__.__name__}) " - f"with DIFFERENT type {module.__class__.__name__}" - ) # same type, new instance → debug elif current_module is not None \ and module is not None \ @@ -1285,26 +1266,34 @@ def register_components(self, **kwargs): f"(same type {type(current_module).__name__}, new instance)" ) + # save modular_model_index.json config + self.register_to_config(**register_dict) + # update component spec + self._component_specs[name] = new_component_spec # finally set models setattr(self, name, module) - def __init__(self, component_specs: List[ComponentSpec], config_specs: Optional[List[ConfigSpec]]=None): - self.component_specs = { - spec.name: deepcopy(spec) for spec in (component_specs or []) + # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name + def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]]): + """ + Initialize the loader with a list of component specs and config specs. + """ + self._component_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) } - self.config_specs = { - spec.name: deepcopy(spec) for spec in (config_specs or []) + self._config_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) } register_components_dict = {} - for component_spec in self.component_specs.values(): - register_components_dict[component_spec.name] = None + for name, component_spec in self._component_specs.items(): + register_components_dict[name] = None self.register_components(**register_components_dict) default_configs = {} - for config_spec in self.config_specs.values(): - default_configs[config_spec.name] = config_spec.default + for name, config_spec in self._config_specs.items(): + default_configs[name] = config_spec.default self.register_to_config(**default_configs) @@ -1366,145 +1355,115 @@ def components(self) -> Dict[str, Any]: # return only components we've actually set as attributes on self return { name: getattr(self, name) - for name in self.component_specs.keys() + for name in self._component_specs.keys() if hasattr(self, name) } - def update(self, repo=None, **kwargs): + def update(self, **kwargs): """ Update components and configs after instance creation. Args: - repo (str, optional): Default repository to use for all components - **kwargs: - Updates, which can be: - - For components: - - A string: Used as the repository name - - A tuple: (repo, subfolder) or (repo,) - - A ComponentSpec: Replace the existing spec - - If the component is already loaded, it will be reloaded with updated info; - otherwise only the spec is updated. - - - For configs: - - Any value: Update the config value - - - Additional loader options: - Passed through to the underlying component loading methods - (e.g., from_pretrained), such as torch_dtype, revision, variant, etc. - """ - # extract component_updates from `kwargs``: - # e.g. loader.update(unet=..., vae=...)` -> {"unet": ..., "vae": ...} - component_updates = {k: kwargs.pop(k) for k in self.component_specs.keys() if k in kwargs} - # extract config_updates from `kwargs``: - # e.g. loader.update(requires_aesthetics_score=False) -> {"requires_aesthetics_score": False} - config_updates = {k: kwargs.pop(k) for k in self.config_specs.keys() if k in kwargs} + """ + """ + Update components and configuration values after the loader has been instantiated. - # create a dict to contain all the component specs to be updated, - new_components_specs = {} - - # update global default repo on each component spec - # e.g loader.update(repo="new_repo") -> {"unet": ComponentSpec(repo="new_repo", ...), "vae": ComponentSpec(repo="new_repo", ...)} - if repo is not None: - for spec in self.component_specs.values(): - new_spec = deepcopy(spec) - new_spec.repo = repo - new_components_specs[spec.name] = new_spec - - # update component specs with component updates extracted from the `kwargs` - # YiYi Notes: should we automatically reload? - for name, new_value in component_updates.items(): - # make a copy of the spec to avoid partial mutation - new_spec = deepcopy(self.component_specs[name]) - if isinstance(new_value, ComponentSpec): - # e.g. loader.update(unet = ComponentSpec(type_hint=UNet2DConditionModel, ...)) - new_spec = new_value - elif isinstance(new_value, str): - # e.g. loader.update(unet="repo/unet") - new_spec.repo = new_value - elif isinstance(new_value, (tuple, list)): - # e.g. loader.update(unet = ("repo/unet", "subfolder")) - new_spec.repo = new_value[0] - if len(new_value) > 1: - new_spec.subfolder = new_value[1] + This method allows you to: + 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) + 2. Update configuration values (e.g., changing requires_safety_checker flag) + + Args: + **kwargs: Component objects or configuration values to update: + - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) + - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) + + Raises: + ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) - # potentially override the spec if global repo is provided - new_components_specs[name] = new_spec - - # attempt to update the components if it's already loaded - components_to_register = {} - for name, new_component_spec in new_components_specs.items(): - if getattr(self, name, None) is not None: - try: - # perform atomic update only if successful load the new component - # load, update components_spec and register_components - new_component = new_component_spec.load(**kwargs) - self.component_specs[name] = new_component_spec - components_to_register[name] = new_component - except Exception as e: - logger.warning(f"Failed to update component '{name}': {e}") - else: - # only update the spec if the component is not loaded (e.g. self.unet = None) - self.component_specs[name] = new_component_spec - components_to_register[name] = None + Examples: + ```python + # Update multiple components at once + loader.update( + unet=new_unet_model, + text_encoder=new_text_encoder + ) + + # Update configuration values + loader.update( + requires_safety_checker=False, + guidance_rescale=0.7 + ) + + # Update both components and configs together + loader.update( + unet=new_unet_model, + requires_safety_checker=False + ) + ``` + """ + + # extract component_specs_updates & config_specs_updates from `specs` + passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} + passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} + + for name, component in passed_components.items(): + if not hasattr(component, "_diffusers_load_id"): + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + if len(kwargs) > 0: + raise logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - self.register_components(**components_to_register) + + self.register_components(**passed_components) + config_to_register = {} - for name, new_value in config_updates.items(): - if isinstance(new_value, ConfigSpec): - # e.g. requires_aesthetics_score = ConfigSpec(name="requires_aesthetics_score", default=False) - self.config_specs[name] = new_value - config_to_register[name] = new_value.default - else: - # e.g. requires_aesthetics_score = False - self.config_specs[name].default = new_value - config_to_register[name] = new_value + for name, new_value in passed_config_values.items(): + + # e.g. requires_aesthetics_score = False + self._config_specs[name].default = new_value + config_to_register[name] = new_value self.register_to_config(**config_to_register) - def load(self, **kwargs): + + # YiYi TODO: support map for additional from_pretrained kwargs + def load(self, component_names: List[str], **kwargs): """ - Load components and optionally set config values. - - This method has three modes: - 1. `self.load()` - load all components from their specs - 2. `self.load(unet=unet, text_encoder=text_encoder)` - use provided components directly, - load remaining components from specs - 3. `self.load(...,requires_aesthetics_score=False)` - additinally set config values + Load selectedcomponents from specs. Args: - **kwargs: Can include: - - Component objects to set directly (e.g., unet=my_unet) - - config values to set (e.g., requires_aesthetics_score=False) - - additional kwargs to be passed to `from_pretrained()`, e.g. torch_dtype=torch.bfloat16 - - Returns: - self: The loader instance with loaded components + component_names: List of component names to load + **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: + - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 + - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} + - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. """ - config_updates = {k: kwargs.pop(k) for k in self.config_specs.keys() if k in kwargs} - passed_component_obj = {k: kwargs.pop(k) for k in self.component_specs.keys() if k in kwargs} - - # 1. Set any config values provided (without updating defaults in specs) - if config_updates: - self.register_to_config(**config_updates) + if not isinstance(component_names, list): + component_names = [component_names] + + components_to_load = set([name for name in component_names if name in self._component_specs]) + unknown_component_names = set([name for name in component_names if name not in self._component_specs]) + if len(unknown_component_names) > 0: + logger.warning(f"Unknown components will be ignored: {unknown_component_names}") - # 2. Process components components_to_register = {} - - # First register the components provided directly - for name, component in passed_component_obj.items(): - components_to_register[name] = component - - # Then load the remaining components from specs - remaining_components = set(self.component_specs.keys()) - set(passed_component_obj.keys()) - for name in remaining_components: - spec = self.component_specs[name] + for name in components_to_load: + spec = self._component_specs[name] + component_load_kwargs = {} + for key, value in kwargs.items(): + if not isinstance(value, dict): + # if the value is a single value, apply it to all components + component_load_kwargs[key] = value + else: + if name in value: + # if it is a dict, check if the component name is in the dict + component_load_kwargs[key] = value[name] + elif "default" in value: + # check if the default is specified + component_load_kwargs[key] = value["default"] try: - if spec.repo is not None: - components_to_register[name] = spec.load(**kwargs) - elif spec.config is not None: - components_to_register[name] = spec.create() + components_to_register[name] = spec.create(**component_load_kwargs) except Exception as e: logger.warning(f"Failed to create component '{name}': {e}") @@ -1518,16 +1477,21 @@ def to(self, *args, **kwargs): # YiYi TODO: # 1. should support save some components too! currently only modular_model_index.json is saved # 2. maybe order the json file to make it more readable: configs first, then components - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): - component_names = list(self.component_specs.keys()) - config_names = list(self.config_specs.keys()) + component_names = list(self._component_specs.keys()) + config_names = list(self._config_specs.keys()) self.register_to_config(_components_names=component_names, _configs_names=config_names) self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + config.pop("_components_names", None) + config.pop("_configs_names", None) + self._internal_dict = FrozenDict(config) + @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) expected_component = set(config_dict.pop("_components_names")) @@ -1537,12 +1501,97 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P config_specs = [] for name, value in config_dict.items(): if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: - library_name, class_name, load_spec_dict = value - type_hint = simple_import_class_obj(library_name, class_name) - component_specs.append(ComponentSpec(name=name, type_hint=type_hint, **load_spec_dict)) + library, class_name, component_spec_dict = value + component_spec = cls._dict_to_component_spec(name, component_spec_dict) + component_specs.append(component_spec) elif name in expected_config: config_specs.append(ConfigSpec(name=name, default=value)) - return cls(component_specs, config_specs=config_specs) + return cls(component_specs + config_specs) + + + @staticmethod + def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: + """ + Convert a ComponentSpec into a JSON‐serializable dict for saving in + `modular_model_index.json`. + + This dict contains: + - "type_hint": Tuple[str, str] + The fully‐qualified module path and class name of the component. + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. + + Args: + component_spec (ComponentSpec): + The spec object describing one pipeline component. + + Returns: + Dict[str, Any]: A mapping suitable for JSON serialization. + + Example: + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec + >>> from diffusers.models.unet import UNet2DConditionModel + >>> spec = ComponentSpec( + ... name="unet", + ... type_hint=UNet2DConditionModel, + ... config=None, + ... repo="path/to/repo", + ... subfolder="subfolder", + ... variant=None, + ... revision=None, + ... default_creation_method="from_pretrained", + ... ) + >>> ModularLoader._component_spec_to_dict(spec) + { + "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": None, + "revision": None, + } + """ + if component_spec.type_hint is not None: + lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) + else: + lib_name = None + cls_name = None + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + return { + "type_hint": (lib_name, cls_name), + **load_spec_dict, + } + + @staticmethod + def _dict_to_component_spec( + name: str, + spec_dict: Dict[str, Any], + ) -> ComponentSpec: + """ + Reconstruct a ComponentSpec from a dict. + """ + # make a shallow copy so we can pop() safely + spec_dict = spec_dict.copy() + # pull out and resolve the stored type_hint + lib_name, cls_name = spec_dict.pop("type_hint") + if lib_name is not None and cls_name is not None: + type_hint = simple_import_class_obj(lib_name, cls_name) + else: + type_hint = None + + # re‐assemble the ComponentSpec + return ComponentSpec( + name=name, + type_hint=type_hint, + **spec_dict, + ) diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index cc01138e8d3a..05eff8f549da 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -13,66 +13,197 @@ # limitations under the License. import re -from dataclasses import dataclass, asdict -from typing import Any, Dict, List, Optional, Tuple, Type, Union +import inspect +from dataclasses import dataclass, asdict, field, fields +from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal from ..utils.import_utils import is_torch_available -from ..configuration_utils import FrozenDict +from ..configuration_utils import FrozenDict, ConfigMixin if is_torch_available(): import torch +# YiYi TODO: +# 1. validate the dataclass fields +# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() @dataclass class ComponentSpec: - """Specification for a pipeline component.""" - name: str - type_hint: Type # YiYi Notes: change to component_type? + """Specification for a pipeline component. + + A component can be created in two ways: + 1. From scratch using __init__ with a config dict + 2. using `from_pretrained` + + Attributes: + name: Name of the component + type_hint: Type of the component (e.g. UNet2DConditionModel) + description: Optional description of the component + config: Optional config dict for __init__ creation + repo: Optional repo path for from_pretrained creation + subfolder: Optional subfolder in repo + variant: Optional variant in repo + revision: Optional revision in repo + default_creation_method: Preferred creation method - "from_config" or "from_pretrained" + """ + name: Optional[str] = None + type_hint: Optional[Type] = None description: Optional[str] = None - config: Optional[FrozenDict[str, Any]] = None # you can specific default config to create a default component if it is a stateless class like scheduler, guider or image processor - repo: Optional[Union[str, List[str]]] = None - subfolder: Optional[str] = None + config: Optional[FrozenDict[str, Any]] = None + # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name + repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) + subfolder: Optional[str] = field(default=None, metadata={"loading": True}) + variant: Optional[str] = field(default=None, metadata={"loading": True}) + revision: Optional[str] = field(default=None, metadata={"loading": True}) + default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" - def create(self, **kwargs) -> Any: + + @classmethod + def from_component(cls, name: str, component: torch.nn.Module) -> Any: + """Create a ComponentSpec from a Component created by `create` method.""" + + if not hasattr(component, "_diffusers_load_id"): + raise ValueError("Component is not created by `create` method") + + type_hint = component.__class__ + + if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): + config = component.config + else: + config = None + + load_spec = cls.decode_load_id(component._diffusers_load_id) + + return cls(name=name, type_hint=type_hint, config=config, **load_spec) + + @classmethod + def loading_fields(cls) -> List[str]: + """ + Return the names of all loading‐related fields + (i.e. those whose field.metadata["loading"] is True). + """ + return [f.name for f in fields(cls) if f.metadata.get("loading", False)] + + + @property + def load_id(self) -> str: """ - Create the component based on the config and additional kwargs. + Unique identifier for this spec's pretrained load, + composed of repo|subfolder|variant|revision (no empty segments). + """ + parts = [getattr(self, k) for k in self.loading_fields()] + parts = ["null" if p is None else p for p in parts] + return "|".join(p for p in parts if p) + + @classmethod + def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: + """ + Decode a load_id string back into a dictionary of loading fields and values. Args: - **kwargs: Additional arguments to pass to the component's __init__ method - + load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" + where None values are represented as "null" + Returns: - The created component + Dict mapping loading field names to their values. e.g. + { + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": "variant", + "revision": "revision" + } + If a segment value is "null", it's replaced with None. + Returns None if load_id is "null" (indicating component not loaded from pretrained). """ - if self.config is not None: - init_kwargs = self.config - else: - init_kwargs = {} - return self.type_hint(**init_kwargs, **kwargs) + if load_id == "null": + return None + + # Get all loading fields in order + loading_fields = cls.loading_fields() + result = {f: None for f in loading_fields} + + # Split the load_id + parts = load_id.split("|") + + # Map parts to loading fields by position + for i, part in enumerate(parts): + if i < len(loading_fields): + # Convert "null" string back to None + result[loading_fields[i]] = None if part == "null" else part + + return result - def load(self, **kwargs) -> Any: - return self.to_load_spec().load(**kwargs) + # YiYi TODO: add validator + def create(self, **kwargs) -> Any: + """Create the component using the preferred creation method.""" + + # from_pretrained creation + if self.default_creation_method == "from_pretrained": + return self.create_from_pretrained(**kwargs) + elif self.default_creation_method == "from_config": + # from_config creation + return self.create_from_config(**kwargs) + else: + raise ValueError(f"Invalid creation method: {self.default_creation_method}") - def to_load_spec(self) -> "ComponentLoadSpec": - """Convert to a ComponentLoadSpec for storage in ComponentsManager.""" - return ComponentLoadSpec.from_component_spec(self) - -@dataclass -class ComponentLoadSpec: - type_hint: type - repo: Optional[str] = None - subfolder: Optional[str] = None - - def load(self, **kwargs) -> Any: - """Load the component from the repository.""" - repo = kwargs.pop("repo", self.repo) - subfolder = kwargs.pop("subfolder", self.subfolder) + def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + """Create component using from_config with config.""" - return self.type_hint.from_pretrained(repo, subfolder=subfolder, **kwargs) + if self.type_hint is None: + raise ValueError( + f"`type_hint` is required when using from_config creation method." + ) + if not (isinstance(self.type_hint, type) and issubclass(self.type_hint, ConfigMixin)): + raise ValueError( + f"cannot create {self.type_hint} using from_config " + "because it is not a `ConfigMixin`." + ) + + config = config or self.config + + try: + component = self.type_hint.from_config(config, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from config: {e}") + component._diffusers_load_id = "null" + self.config = component.config + + return component + + # YiYi TODO: add guard for type of model, if it is supported by from_pretrained + def create_from_pretrained(self, **kwargs) -> Any: + """Create component using from_pretrained.""" + + passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} + # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path + repo = load_kwargs.pop("repo", None) + if repo is None: + raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + + if self.type_hint is None: + try: + from diffusers import AutoModel + component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") + self.type_hint = component.__class__ + else: + try: + component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") + + if repo != self.repo: + self.repo = repo + for k, v in passed_loading_kwargs.items(): + if v is not None: + setattr(self, k, v) + component._diffusers_load_id = self.load_id + + return component - @classmethod - def from_component_spec(cls, component_spec: ComponentSpec): - return cls(type_hint=component_spec.type_hint, repo=component_spec.repo, subfolder=component_spec.subfolder) @dataclass diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index a9d6c561af34..813566434f52 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -841,7 +841,10 @@ def _fetch_class_library_tuple(module): library = not_compiled_module.__module__ # retrieve class_name - class_name = not_compiled_module.__class__.__name__ + if isinstance(not_compiled_module, type): + class_name = not_compiled_module.__name__ + else: + class_name = not_compiled_module.__class__.__name__ return (library, class_name) From 1952941a9c71a3ffa280b6ad643745780c5fbe34 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 29 Apr 2025 11:02:47 +0200 Subject: [PATCH 31/39] up! --- src/diffusers/__init__.py | 4 +- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/components_manager.py | 8 +- src/diffusers/pipelines/modular_pipeline.py | 417 ++++++++---------- .../pipelines/modular_pipeline_utils.py | 76 +++- .../pipelines/pipeline_loading_utils.py | 14 + .../pipelines/stable_diffusion_xl/__init__.py | 4 +- .../pipeline_stable_diffusion_xl_modular.py | 90 ++-- .../dummy_torch_and_transformers_objects.py | 2 +- 9 files changed, 322 insertions(+), 297 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 424011961ab0..5ddfbc4b3a4b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -503,7 +503,7 @@ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLModularPipeline", + "StableDiffusionXLModularLoader", "StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", @@ -1073,7 +1073,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularPipeline, + StableDiffusionXLModularLoader, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index aee275db0336..e0b47681af6f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -329,7 +329,7 @@ "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLModularPipeline", + "StableDiffusionXLModularLoader", "StableDiffusionXLAutoPipeline", ] ) @@ -693,7 +693,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularPipeline, + StableDiffusionXLModularLoader, StableDiffusionXLPipeline, StableDiffusionXLAutoPipeline, ) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index c5934a8be768..eaa2abaa7f8c 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -236,12 +236,15 @@ class ComponentsManager: def __init__(self): self.components = OrderedDict() self.added_time = OrderedDict() # Store when components were added - self.load_ids = OrderedDict() # Store load_id of components (for model loaded with ComponentSpec) self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False def add(self, name, component, collection: Optional[str] = None): + + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + name = f"{name}_{component._diffusers_load_id}" + if name in self.components: logger.warning(f"Overriding existing component '{name}' in ComponentsManager") @@ -251,9 +254,6 @@ def add(self, name, component, collection: Optional[str] = None): if collection not in self.collections: self.collections[collection] = set() self.collections[collection].add(name) - - if hasattr(component, "_diffusers_load_id"): - self.load_ids[name] = component._diffusers_load_id if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 226c26a0b75d..3ab462934328 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -23,6 +23,7 @@ from tqdm.auto import tqdm import re import os +import importlib from huggingface_hub.utils import validate_hf_hub_args @@ -33,7 +34,7 @@ logging, PushToHubMixin, ) -from .pipeline_loading_utils import _get_pipeline_class +from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple from .modular_pipeline_utils import ( ComponentSpec, ConfigSpec, @@ -48,6 +49,7 @@ format_params, make_doc_string, ) +from .components_manager import ComponentsManager from copy import deepcopy if is_accelerate_available(): @@ -156,7 +158,116 @@ def format_value(v): return f"BlockState(\n{attributes}\n)" -class PipelineBlock: + +class ModularPipelineMixin: + """ + Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks + """ + + + def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + """ + create a mouldar loader, optionally accept modular_repo to load from hub. + """ + + # Import components loader (it is model-specific class) + loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] + diffusers_module = importlib.import_module(self.__module__.split(".")[0]) + loader_class = getattr(diffusers_module, loader_class_name) + + # Create deep copies to avoid modifying the original specs + component_specs = deepcopy(self.expected_components) + config_specs = deepcopy(self.expected_configs) + # Create the loader with the updated specs + specs = component_specs + config_specs + + self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) + + + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for input_param in self.inputs: + params[input_param.name] = input_param.default + return params + + def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() + + if not hasattr(self, "loader"): + raise ValueError("Loader is not set, please call `setup_loader()` first.") + + # Make a copy of the input kwargs + input_params = kwargs.copy() + + default_params = self.default_call_parameters + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + + intermediates_inputs = [inp.name for inp in self.intermediates_inputs] + for name, default in default_params.items(): + if name in input_params: + if name not in intermediates_inputs: + state.add_input(name, input_params.pop(name)) + else: + state.add_input(name, input_params[name]) + elif name not in state.inputs: + state.add_input(name, default) + + for name in intermediates_inputs: + if name in input_params: + state.add_intermediate(name, input_params.pop(name)) + + # Warn about unexpected inputs + if len(input_params) > 0: + logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + pipeline, state = self(self.loader, state) + except Exception: + error_msg = f"Error in block: ({self.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + + elif isinstance(output, str): + return state.get_intermediate(output) + + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + +class PipelineBlock(ModularPipelineMixin): model_name = None @@ -356,7 +467,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> return list(combined_dict.values()) -class AutoPipelineBlocks: +class AutoPipelineBlocks(ModularPipelineMixin): """ A class that automatically selects a block to run based on the inputs. @@ -583,18 +694,6 @@ def __repr__(self): expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Inputs and outputs section - moved up - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - - outputs = [out.name for out in self.outputs] - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - " Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) - # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): @@ -624,11 +723,9 @@ def __repr__(self): return ( f"{header}\n" - f"{desc}" - f"{components_str}\n" - f"{configs_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" + f"{desc}\n\n" + f"{components_str}\n\n" + f"{configs_str}\n\n" f"{blocks_str}" f")" ) @@ -646,7 +743,7 @@ def doc(self): expected_configs=self.expected_configs ) -class SequentialPipelineBlocks: +class SequentialPipelineBlocks(ModularPipelineMixin): """ A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. """ @@ -919,7 +1016,7 @@ def __repr__(self): desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section - use format_components with add_empty_lines=False + # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) @@ -927,18 +1024,6 @@ def __repr__(self): expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Inputs and outputs section - moved up - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - - outputs = [out.name for out in self.outputs] - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - " Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) - # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): @@ -968,11 +1053,9 @@ def __repr__(self): return ( f"{header}\n" - f"{desc}" - f"{components_str}\n" - f"{configs_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" + f"{desc}\n\n" + f"{components_str}\n\n" + f"{configs_str}\n\n" f"{blocks_str}" f")" ) @@ -992,194 +1075,6 @@ def doc(self): -class ModularPipelineMixin: - """ - Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks - """ - - # def register_loader(self, global_components_manager: ComponentsManager, label: Optional[str] = None): - # self._global_components_manager = global_components_manager - # self._label = label - - #YiYi TODO: add validation for kwargs? - def setup_loader(self, **kwargs): - """ - Set up the components loader with repository information. - - Args: - **kwargs: Configuration for component loading. - - repo: Default repository to use for all components - - For individual components, pass a tuple of (repo, subfolder) - e.g., text_encoder=("repo_name", "text_encoder") - - Examples: - # Set repo for all components (subfolder will be component name) - setup_loader(repo="stabilityai/stable-diffusion-xl-base-1.0") - - # Set specific repo/subfolder for individual components - setup_loader( - unet=("stabilityai/stable-diffusion-xl-base-1.0", "unet"), - text_encoder=("stabilityai/stable-diffusion-xl-base-1.0", "text_encoder") - ) - - # Set default repo and override for specific components - setup_loader( - repo="stabilityai/stable-diffusion-xl-base-1.0", - unet=(""stabilityai/stable-diffusion-xl-refiner-1.0", "unet") - ) - """ - - # Create deep copies to avoid modifying the original specs - component_specs = deepcopy(self.expected_components) - config_specs = deepcopy(self.expected_configs) - - expected_component_names = set([c.name for c in component_specs]) - expected_config_names = set([c.name for c in config_specs]) - - # Check if a default repo is provided - repo = kwargs.pop("repo", None) - revision = kwargs.pop("revision", None) - variant = kwargs.pop("variant", None) - - passed_component_kwargs = {k: kwargs.pop(k) for k in expected_component_names if k in kwargs} - passed_config_kwargs = {k: kwargs.pop(k) for k in expected_config_names if k in kwargs} - if len(kwargs) > 0: - logger.warning(f"Unused keyword arguments: {kwargs.keys()}. This input will be ignored.") - - for name, value in passed_component_kwargs.items(): - if not isinstance(value, (tuple, list, str)): - raise ValueError(f"Invalid value for component '{name}': {value}. Expected a string, tuple or list") - elif isinstance(value, (tuple, list)) and len(value) > 2: - raise ValueError(f"Invalid value for component '{name}': {value}. Expected a tuple or list of length 1 or 2.") - - for name, value in passed_config_kwargs.items(): - if not isinstance(value, str): - raise ValueError(f"Invalid value for config '{name}': {value}. Expected a string") - - # First apply default repo to all components if provided - if repo is not None: - for component_spec in component_specs: - # components defined with a config are classes like image_processor or guider, - # skip setting loading related attributes for them, they should be initialized with the default config - if component_spec.config is None: - component_spec.repo = repo - - # YiYi TODO: should also accept `revision` and `variant` as a dict here so user can set different values for different components - if revision is not None: - component_spec.revision = revision - if variant is not None: - component_spec.variant = variant - for config_spec in config_specs: - config_spec.repo = repo - - # apply component-specific overrides - for name, value in passed_component_kwargs.items(): - if not isinstance(value, (tuple, list)): - value = (value,) - # Find the matching component spec - for component_spec in component_specs: - if component_spec.name == name: - # Handle tuple of (repo, subfolder) - component_spec.repo = value[0] - if len(value) > 1: - component_spec.subfolder = value[1] - break - - # apply config overrides - for name, value in passed_config_kwargs.items(): - for config_spec in config_specs: - if config_spec.name == name: - config_spec.repo = value - break - - # Import components loader (it is model-specific class) - loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] - diffusers_module = importlib.import_module(self.__module__.split(".")[0]) - loader_class = getattr(diffusers_module, loader_class_name) - - # Create the loader with the updated specs - self.loader = loader_class(component_specs, config_specs) - - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.inputs: - params[input_param.name] = input_param.default - return params - - def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - # Make a copy of the input kwargs - input_params = kwargs.copy() - - default_params = self.default_call_parameters - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediates_inputs = [inp.name for inp in self.intermediates_inputs] - for name, default in default_params.items(): - if name in input_params: - if name not in intermediates_inputs: - state.add_input(name, input_params.pop(name)) - else: - state.add_input(name, input_params[name]) - elif name not in state.inputs: - state.add_input(name, default) - - for name in intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) - - # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self(self, state) - except Exception: - error_msg = f"Error in block: ({self.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - - - - -from .pipeline_loading_utils import _fetch_class_library_tuple -import importlib -def simple_import_class_obj(library_name, class_name): - from diffusers import pipelines - is_pipeline_module = hasattr(pipelines, library_name) - - if is_pipeline_module: - pipeline_module = getattr(pipelines, library_name) - class_obj = getattr(pipeline_module, class_name) - else: - library = importlib.import_module(library_name) - class_obj = getattr(library, class_name) - - return class_obj - - # YiYi TODO: # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader @@ -1228,9 +1123,12 @@ def register_components(self, **kwargs): # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) new_component_spec = component_spec component_spec_dict = self._component_spec_to_dict(component_spec) - - - register_dict = {name: (library, class_name, component_spec_dict)} + + # do not register if component is not to be loaded from pretrained + if new_component_spec.default_creation_method == "from_pretrained": + register_dict = {name: (library, class_name, component_spec_dict)} + else: + register_dict = {} # set the component as attribute # if it is not set yet, just set it and skip the process to check and warn below @@ -1238,6 +1136,8 @@ def register_components(self, **kwargs): self.register_to_config(**register_dict) self._component_specs[name] = new_component_spec setattr(self, name, module) + if module is not None and self._component_manager is not None: + self._component_manager.add(name, module, self._collection) continue current_module = getattr(self, name, None) @@ -1272,13 +1172,18 @@ def register_components(self, **kwargs): self._component_specs[name] = new_component_spec # finally set models setattr(self, name, module) + if module is not None and self._component_manager is not None: + self._component_manager.add(name, module, self._collection) + # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name - def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]]): + def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): """ Initialize the loader with a list of component specs and config specs. """ + self._component_manager = component_manager + self._collection = collection self._component_specs = { spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) } @@ -1286,6 +1191,19 @@ def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]]): spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) } + # update component_specs and config_specs from modular_repo + if modular_repo is not None: + config_dict = self.load_config(modular_repo, **kwargs) + + for name, value in config_dict.items(): + if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = self._dict_to_component_spec(name, component_spec_dict) + self._component_specs[name] = component_spec + + elif name in self._config_specs: + self._config_specs[name].default = value + register_components_dict = {} for name, component_spec in self._component_specs.items(): register_components_dict[name] = None @@ -1320,7 +1238,7 @@ def _execution_device(self): Accelerate's module hooks. """ for name, model in self.components.items(): - if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: + if not isinstance(model, torch.nn.Module): continue if not hasattr(model, "_hf_hook"): @@ -1333,8 +1251,21 @@ def _execution_device(self): ): return torch.device(module._hf_hook.execution_device) return self.device + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + + modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") - @property def dtype(self) -> torch.dtype: r""" @@ -1428,7 +1359,7 @@ def update(self, **kwargs): # YiYi TODO: support map for additional from_pretrained kwargs - def load(self, component_names: List[str], **kwargs): + def load(self, component_names: Optional[List[str]] = None, **kwargs): """ Load selectedcomponents from specs. @@ -1439,7 +1370,9 @@ def load(self, component_names: List[str], **kwargs): - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. """ - if not isinstance(component_names, list): + if component_names is None: + component_names = list(self._component_specs.keys()) + elif not isinstance(component_names, list): component_names = [component_names] components_to_load = set([name for name in component_names if name in self._component_specs]) @@ -1507,6 +1440,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P elif name in expected_config: config_specs.append(ConfigSpec(name=name, default=value)) + + for name in expected_component: + for spec in component_specs: + if spec.name == name: + break + else: + # append a empty component spec for these not in modular_model_index + component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) return cls(component_specs + config_specs) @@ -1583,7 +1524,7 @@ def _dict_to_component_spec( # pull out and resolve the stored type_hint lib_name, cls_name = spec_dict.pop("type_hint") if lib_name is not None and cls_name is not None: - type_hint = simple_import_class_obj(lib_name, cls_name) + type_hint = simple_get_class_obj(lib_name, cls_name) else: type_hint = None @@ -1592,6 +1533,4 @@ def _dict_to_component_spec( name=name, type_hint=type_hint, **spec_dict, - ) - - + ) \ No newline at end of file diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index 05eff8f549da..c8064a5215aa 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -58,6 +58,18 @@ class ComponentSpec: default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" + def __hash__(self): + """Make ComponentSpec hashable, using load_id as the hash value.""" + return hash((self.name, self.load_id, self.default_creation_method)) + + def __eq__(self, other): + """Compare ComponentSpec objects based on name and load_id.""" + if not isinstance(other, ComponentSpec): + return False + return (self.name == other.name and + self.load_id == other.load_id and + self.default_creation_method == other.default_creation_method) + @classmethod def from_component(cls, name: str, component: torch.nn.Module) -> Any: """Create a ComponentSpec from a Component created by `create` method.""" @@ -76,6 +88,18 @@ def from_component(cls, name: str, component: torch.nn.Module) -> Any: return cls(name=name, type_hint=type_hint, config=config, **load_spec) + @classmethod + def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: + """Create a ComponentSpec from a load_id string.""" + if load_id == "null": + raise ValueError("Cannot create ComponentSpec from null load_id") + + # Decode the load_id into a dictionary of loading fields + load_fields = cls.decode_load_id(load_id) + + # Create a new ComponentSpec instance with the decoded fields + return cls(name=name, **load_fields) + @classmethod def loading_fields(cls) -> List[str]: """ @@ -115,12 +139,13 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating component not loaded from pretrained). """ - if load_id == "null": - return None # Get all loading fields in order loading_fields = cls.loading_fields() result = {f: None for f in loading_fields} + + if load_id == "null": + return result # Split the load_id parts = load_id.split("|") @@ -149,25 +174,29 @@ def create(self, **kwargs) -> Any: def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: """Create component using from_config with config.""" - if self.type_hint is None: + if self.type_hint is None or not isinstance(self.type_hint, type): raise ValueError( f"`type_hint` is required when using from_config creation method." ) - if not (isinstance(self.type_hint, type) and issubclass(self.type_hint, ConfigMixin)): - raise ValueError( - f"cannot create {self.type_hint} using from_config " - "because it is not a `ConfigMixin`." - ) - config = config or self.config + config = config or self.config or {} - try: + if issubclass(self.type_hint, ConfigMixin): component = self.type_hint.from_config(config, **kwargs) - except Exception as e: - raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from config: {e}") + else: + signature_params = inspect.signature(self.type_hint.__init__).parameters + init_kwargs = {} + for k, v in config.items(): + if k in signature_params: + init_kwargs[k] = v + for k, v in kwargs.items(): + if k in signature_params: + init_kwargs[k] = v + component = self.type_hint(**init_kwargs) component._diffusers_load_id = "null" - self.config = component.config + if hasattr(component, "config"): + self.config = component.config return component @@ -455,15 +484,18 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty component_desc = f"{component_indent}{component.name} (`{type_name}`)" if component.description: component_desc += f": {component.description}" - if component.default_repo: - if isinstance(component.default_repo, list) and len(component.default_repo) == 2: - repo_info = component.default_repo[0] - subfolder = component.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component.default_repo - component_desc += f" [{repo_info}]" + + # Get the loading fields dynamically + loading_field_values = [] + for field_name in component.loading_fields(): + field_value = getattr(component, field_name) + if field_value is not None: + loading_field_values.append(f"{field_name}={field_value}") + + # Add loading field information if available + if loading_field_values: + component_desc += f" [{', '.join(loading_field_values)}]" + formatted_components.append(component_desc) # Add an empty line after each component except the last one diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 813566434f52..46ec1c0d4344 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -333,6 +333,20 @@ def maybe_raise_or_warn( ) +# a simpler version of get_class_obj_and_candidates, it won't work with custom code +def simple_get_class_obj(library_name, class_name): + from diffusers import pipelines + is_pipeline_module = hasattr(pipelines, library_name) + + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + else: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + + return class_obj + def get_class_obj_and_candidates( library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None ): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 584b260eaaa8..006836fe30d4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -34,7 +34,7 @@ "StableDiffusionXLDecodeLatentsStep", "StableDiffusionXLDenoiseStep", "StableDiffusionXLInputStep", - "StableDiffusionXLModularPipeline", + "StableDiffusionXLModularLoader", "StableDiffusionXLPrepareAdditionalConditioningStep", "StableDiffusionXLPrepareLatentsStep", "StableDiffusionXLSetTimestepsStep", @@ -65,7 +65,7 @@ StableDiffusionXLDecodeLatentsStep, StableDiffusionXLDenoiseStep, StableDiffusionXLInputStep, - StableDiffusionXLModularPipeline, + StableDiffusionXLModularLoader, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLSetTimestepsStep, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 1ff0befb1597..0f249e70baf5 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -56,8 +56,8 @@ CLIPVisionModelWithProjection, ) -from ...schedulers import KarrasDiffusionSchedulers -from ...guiders import GuiderType, ClassifierFreeGuidance +from ...schedulers import EulerDiscreteScheduler +from ...guiders import ClassifierFreeGuidance from ...configuration_utils import FrozenDict import numpy as np @@ -183,9 +183,13 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("image_encoder", CLIPVisionModelWithProjection), - ComponentSpec("feature_extractor", CLIPImageProcessor), + ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("guider", GuiderType), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), ] @property @@ -321,7 +325,11 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", CLIPTokenizer), - ComponentSpec("guider", GuiderType), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), ] @property @@ -647,7 +655,11 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8})), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), ] @property @@ -742,8 +754,16 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8})), - ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True})), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ComponentSpec( + "mask_processor", + VaeImageProcessor, + config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), + default_creation_method="from_config"), ] @@ -1030,7 +1050,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1153,7 +1173,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1208,7 +1228,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1462,7 +1482,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1610,7 +1630,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -2064,8 +2084,12 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", GuiderType, config=FrozenDict({"guidance_scale": 7.5})), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), ] @@ -2240,7 +2264,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) @@ -2333,11 +2357,15 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", GuiderType, config=FrozenDict({"guidance_scale": 7.5})), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False})), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), ] @property @@ -2636,7 +2664,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) # (5) Denoise loop - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) @@ -2763,9 +2791,17 @@ def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetUnionModel), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), - ComponentSpec("guider", GuiderType, config=FrozenDict({"guidance_scale": 7.5})), - ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False})), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config"), ] @property @@ -3052,7 +3088,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.extra_step_kwargs = self.prepare_extra_step_kwargs(pipeline, data.generator, data.eta) data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) @@ -3180,7 +3216,11 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8})) + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), ] @property diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0a2c1eefae12..cbfbb842723a 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2417,7 +2417,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionXLModularPipeline(metaclass=DummyObject): +class StableDiffusionXLModularLoader(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From de8ce5274393b4213f66ce92574bd6c5d465871f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Apr 2025 01:09:33 +0200 Subject: [PATCH 32/39] up --- src/diffusers/pipelines/components_manager.py | 292 ++++++++++++++---- src/diffusers/pipelines/modular_pipeline.py | 12 +- 2 files changed, 246 insertions(+), 58 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index eaa2abaa7f8c..d2c8e9e1f1e1 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -232,6 +232,7 @@ def search_best_candidate(module_sizes, min_memory_offload): from .modular_pipeline_utils import ComponentSpec +import uuid class ComponentsManager: def __init__(self): self.components = OrderedDict() @@ -240,26 +241,65 @@ def __init__(self): self.model_hooks = None self._auto_offload_enabled = False + + def _get_by_collection(self, collection: str): + """ + Select components by collection name. + """ + selected_components = {} + if collection in self.collections: + component_ids = self.collections[collection] + for component_id in component_ids: + selected_components[component_id] = self.components[component_id] + return selected_components + + + def _get_by_load_id(self, load_id: str): + """ + Select components by its load_id. + """ + selected_components = {} + for name, component in self.components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + selected_components[name] = component + return selected_components + + def add(self, name, component, collection: Optional[str] = None): + for comp_id, comp in self.components.items(): + if comp == component: + logger.warning(f"Component '{name}' already exists in ComponentsManager") + return comp_id + + component_id = f"{name}_{uuid.uuid4()}" + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": - name = f"{name}_{component._diffusers_load_id}" - - if name in self.components: - logger.warning(f"Overriding existing component '{name}' in ComponentsManager") + components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) + if components_with_same_load_id: + existing = ", ".join(components_with_same_load_id.keys()) + logger.warning( + f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." + ) - self.components[name] = component - self.added_time[name] = time.time() + + # add component to components manager + self.components[component_id] = component + self.added_time[component_id] = time.time() if collection: if collection not in self.collections: self.collections[collection] = set() - self.collections[collection].add(name) + self.collections[collection].add(component_id) if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) + self.enable_auto_cpu_offload(self._auto_offload_device) + + logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") + return component_id - def remove(self, name): + def remove(self, name: Union[str, List[str]]): if name not in self.components: logger.warning(f"Component '{name}' not found in ComponentsManager") @@ -275,27 +315,83 @@ def remove(self, name): if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - # YiYi TODO: looking into improving the search pattern - def get(self, names: Union[str, List[str]]): + # YiYi TODO: looking into improving the search pattern and refactor the code + def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None): """ - Get components by name with simple pattern matching. + Select components by name with simple pattern matching. Args: names: Component name(s) or pattern(s) Patterns: - - "unet" : exact match - - "!unet" : everything except exact match "unet" - - "base_*" : everything starting with "base_" - - "!base_*" : everything NOT starting with "base_" - - "*unet*" : anything containing "unet" - - "!*unet*" : anything NOT containing "unet" - - "refiner|vae|unet" : anything containing any of these terms - - "!refiner|vae|unet" : anything NOT containing any of these terms + - "unet" : match any component with base name "unet" (e.g., unet_123abc) + - "!unet" : everything except components with base name "unet" + - "unet*" : anything with base name starting with "unet" + - "!unet*" : anything with base name NOT starting with "unet" + - "*unet*" : anything with base name containing "unet" + - "!*unet*" : anything with base name NOT containing "unet" + - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" + - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" + - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" Returns: Single component if names is str and matches one component, dict of components if names matches multiple components or is a list """ + + if collection: + if collection not in self.collections: + logger.warning(f"Collection '{collection}' not found in ComponentsManager") + return {} + components = self._get_by_collection(collection) + else: + components = self.components + + if load_id: + components = self._get_by_load_id(load_id) + + if names is None: + return components + + # Helper to extract base name from component_id + def get_base_name(component_id): + parts = component_id.split('_') + # If the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return component_id + + # Create mapping from component_id to base_name for all components + base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} + + def matches_pattern(component_id, pattern, exact_match=False): + """ + Helper function to check if a component matches a pattern based on its base name. + + Args: + component_id: The component ID to check + pattern: The pattern to match against + exact_match: If True, only exact matches to base_name are considered + """ + base_name = base_names[component_id] + + # Exact match with base name + if exact_match: + return pattern == base_name + + # Prefix match (ends with *) + elif pattern.endswith('*'): + prefix = pattern[:-1] + return base_name.startswith(prefix) + + # Contains match (starts with *) + elif pattern.startswith('*'): + search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] + return search in base_name + + # Exact match (no wildcards) + else: + return pattern == base_name + if isinstance(names, str): # Check if this is a "not" pattern is_not_pattern = names.startswith('!') @@ -305,33 +401,49 @@ def get(self, names: Union[str, List[str]]): # Handle OR patterns (containing |) if '|' in names: terms = names.split('|') + matches = {} + + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) + + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) + + # Flip the decision if this is a NOT pattern + if is_not_pattern: + should_include = not should_include + + if should_include: + matches[comp_id] = comp + + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") + + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name matches = { - name: comp for name, comp in self.components.items() - if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern } + if is_not_pattern: - logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}") - else: - logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}") - - # Exact match - elif names in self.components: - if is_not_pattern: - matches = { - name: comp for name, comp in self.components.items() - if name != names - } - logger.info(f"Getting all components except '{names}': {list(matches.keys())}") + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: - logger.info(f"Getting component: {names}") - return self.components[names] + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") + + # If there's exactly one match and it's not a NOT pattern, return the component directly + if len(matches) == 1 and not is_not_pattern: + return next(iter(matches.values())) # Prefix match (ends with *) elif names.endswith('*'): prefix = names[:-1] matches = { - name: comp for name, comp in self.components.items() - if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") @@ -342,16 +454,27 @@ def get(self, names: Union[str, List[str]]): elif names.startswith('*'): search = names[1:-1] if names.endswith('*') else names[1:] matches = { - name: comp for name, comp in self.components.items() - if (search in name) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") + else: - raise ValueError(f"Component '{names}' not found in ComponentsManager") + raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") if not matches: raise ValueError(f"No components found matching pattern '{names}'") @@ -360,7 +483,7 @@ def get(self, names: Union[str, List[str]]): elif isinstance(names, list): results = {} for name in names: - result = self.get(name) + result = self.get(name, collection) if isinstance(result, dict): results.update(result) else: @@ -409,6 +532,7 @@ def disable_auto_cpu_offload(self): self.model_hooks = None self._auto_offload_enabled = False + # YiYi TODO: add quantization info def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. @@ -431,14 +555,23 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No info = { "model_id": name, "added_time": self.added_time[name], + "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), } # Additional info for torch.nn.Module components if isinstance(component, torch.nn.Module): + # Check for hook information + has_hook = hasattr(component, "_hf_hook") + execution_device = None + if has_hook and hasattr(component._hf_hook, "execution_device"): + execution_device = component._hf_hook.execution_device + info.update({ "class_name": component.__class__.__name__, "size_gb": get_memory_footprint(component) / (1024**3), "adapters": None, # Default to None + "has_hook": has_hook, + "execution_device": execution_device, }) # Get adapters if applicable @@ -472,12 +605,56 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No return info def __repr__(self): + # Helper to get simple name without UUID + def get_simple_name(name): + # Extract the base name by splitting on underscore and taking first part + # This assumes names are in format "name_uuid" + parts = name.split('_') + # If we have at least 2 parts and the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return name + + # Extract load_id if available + def get_load_id(component): + if hasattr(component, "_diffusers_load_id"): + return component._diffusers_load_id + return "N/A" + + # Format device info compactly + def format_device(component, info): + if not info["has_hook"]: + return str(getattr(component, 'device', 'N/A')) + else: + device = str(getattr(component, 'device', 'N/A')) + exec_device = str(info['execution_device'] or 'N/A') + return f"{device}({exec_device})" + + # Get all simple names to calculate width + simple_names = [get_simple_name(id) for id in self.components.keys()] + + # Get max length of load_ids for models + load_ids = [ + get_load_id(component) + for component in self.components.values() + if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") + ] + max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 + + # Collection names + collection_names = [ + next((coll for coll, comps in self.collections.items() if name in comps), "N/A") + for name in self.components.keys() + ] + col_widths = { - "id": max(15, max(len(id) for id in self.components.keys())), + "name": max(15, max(len(name) for name in simple_names)), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), - "device": 10, + "device": 15, # Reduced since using more compact format "dtype": 15, "size": 10, + "load_id": max_load_id_len, + "collection": max(10, max(len(str(c)) for c in collection_names)) } # Create the header lines @@ -494,17 +671,23 @@ def __repr__(self): if models: output += "Models:\n" + dash_line # Column headers - output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " + output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" output += dash_line # Model entries for name, component in models.items(): info = self.get_model_info(name) - device = str(getattr(component, "device", "N/A")) + simple_name = get_simple_name(name) + device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" - output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | " - output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n" + load_id = get_load_id(component) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " + output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" output += dash_line # Other components section @@ -513,12 +696,16 @@ def __repr__(self): output += "\n" output += "Other Components:\n" + dash_line # Column headers for other components - output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" output += dash_line # Other component entries for name, component in others.items(): - output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n" + info = self.get_model_info(name) + simple_name = get_simple_name(name) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" output += dash_line # Add additional component info @@ -526,7 +713,8 @@ def __repr__(self): for name in self.components: info = self.get_model_info(name) if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): - output += f"\n{name}:\n" + simple_name = get_simple_name(name) + output += f"\n{simple_name}:\n" if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 3ab462934328..1f1784b186ab 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -1101,7 +1101,7 @@ def register_components(self, **kwargs): # current component spec component_spec = self._component_specs.get(name) if component_spec is None: - logger.warning(f"register_components: skipping unknown component '{name}'") + logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") continue is_registered = hasattr(self, name) @@ -1143,17 +1143,17 @@ def register_components(self, **kwargs): current_module = getattr(self, name, None) # skip if the component is already registered with the same object if current_module is module: - logger.info(f"register_components: {name} is already registered with same object, skipping") + logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") continue # it module is not an instance of the expected type, still register it but with a warning if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") + logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") # warn if unregister if current_module is not None and module is None: logger.info( - f"register_components: setting '{name}' to None " + f"ModularLoader.register_components: setting '{name}' to None " f"(was {current_module.__class__.__name__})" ) # same type, new instance → debug @@ -1162,7 +1162,7 @@ def register_components(self, **kwargs): and isinstance(module, current_module.__class__) \ and current_module != module: logger.debug( - f"register_components: replacing existing '{name}' " + f"ModularLoader.register_components: replacing existing '{name}' " f"(same type {type(current_module).__name__}, new instance)" ) @@ -1343,7 +1343,7 @@ def update(self, **kwargs): raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") if len(kwargs) > 0: - raise logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") + logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") self.register_components(**passed_components) From 911361379e6e2641a3e9020c283ffa8354a3037e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Apr 2025 04:47:16 +0200 Subject: [PATCH 33/39] component manager update get and get_one --- src/diffusers/pipelines/components_manager.py | 71 ++++++++++++++----- 1 file changed, 52 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index d2c8e9e1f1e1..bdff133e22d9 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -315,8 +315,8 @@ def remove(self, name: Union[str, List[str]]): if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - # YiYi TODO: looking into improving the search pattern and refactor the code - def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None): + def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, + as_name_component_tuples: bool = False): """ Select components by name with simple pattern matching. @@ -332,16 +332,20 @@ def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = N - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" + collection: Optional collection to filter by + load_id: Optional load_id to filter by + as_name_component_tuples: If True, returns a list of (name, component) tuples using base names + instead of a dictionary with component IDs as keys Returns: - Single component if names is str and matches one component, - dict of components if names matches multiple components or is a list + Dictionary mapping component IDs to components, + or list of (base_name, component) tuples if as_name_component_tuples=True """ if collection: if collection not in self.collections: logger.warning(f"Collection '{collection}' not found in ComponentsManager") - return {} + return [] if as_name_component_tuples else {} components = self._get_by_collection(collection) else: components = self.components @@ -349,9 +353,6 @@ def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = N if load_id: components = self._get_by_load_id(load_id) - if names is None: - return components - # Helper to extract base name from component_id def get_base_name(component_id): parts = component_id.split('_') @@ -360,6 +361,12 @@ def get_base_name(component_id): return '_'.join(parts[:-1]) return component_id + if names is None: + if as_name_component_tuples: + return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] + else: + return components + # Create mapping from component_id to base_name for all components base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} @@ -433,10 +440,6 @@ def matches_pattern(component_id, pattern, exact_match=False): logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") - - # If there's exactly one match and it's not a NOT pattern, return the component directly - if len(matches) == 1 and not is_not_pattern: - return next(iter(matches.values())) # Prefix match (ends with *) elif names.endswith('*'): @@ -478,17 +481,22 @@ def matches_pattern(component_id, pattern, exact_match=False): if not matches: raise ValueError(f"No components found matching pattern '{names}'") - return matches if len(matches) > 1 else next(iter(matches.values())) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] + else: + return matches elif isinstance(names, list): results = {} for name in names: - result = self.get(name, collection) - if isinstance(result, dict): - results.update(result) - else: - results[name] = result - return results + result = self.get(name, collection, load_id, as_name_component_tuples=False) + results.update(result) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in results.items()] + else: + return results else: raise ValueError(f"Invalid type for names: {type(names)}") @@ -767,6 +775,31 @@ def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) + def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: + """ + Get a single component by name. Raises an error if multiple components match or none are found. + + Args: + name: Component name or pattern + collection: Optional collection to filter by + load_id: Optional load_id to filter by + + Returns: + A single component + + Raises: + ValueError: If no components match or multiple components match + """ + results = self.get(name, collection, load_id) + + if not results: + raise ValueError(f"No components found matching '{name}'") + + if len(results) > 1: + raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") + + return next(iter(results.values())) + def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. From aaab69c8f3e46289d9a5d30e7c2c276324a42b8a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Apr 2025 06:30:57 +0200 Subject: [PATCH 34/39] fix merge --- .../pipeline_stable_diffusion_xl_modular.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 46ef3963421d..94e01208b23e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -59,8 +59,6 @@ from ...schedulers import EulerDiscreteScheduler from ...guiders import ClassifierFreeGuidance from ...configuration_utils import FrozenDict -from ...schedulers import KarrasDiffusionSchedulers -from ...guiders import GuiderType, ClassifierFreeGuidance import numpy as np @@ -192,7 +190,6 @@ def expected_components(self) -> List[ComponentSpec]: ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), - ComponentSpec("guider", GuiderType), ] @property @@ -333,7 +330,6 @@ def expected_components(self) -> List[ComponentSpec]: ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), - ComponentSpec("guider", GuiderType), ] @property @@ -2093,8 +2089,6 @@ def expected_components(self) -> List[ComponentSpec]: config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ] @@ -2276,7 +2270,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), ) - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) guider_data = pipeline.guider.prepare_inputs(data) @@ -2353,12 +2347,9 @@ def expected_components(self) -> List[ComponentSpec]: config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec("scheduler", EulerDiscreteScheduler), - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), - ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), ] @property @@ -2782,9 +2773,6 @@ def expected_components(self) -> List[ComponentSpec]: VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), ] @property @@ -3078,7 +3066,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), ) - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) guider_data = pipeline.guider.prepare_inputs(data) From c1084b8cb8f55ca1bd8ea39f5e7e3f3818b66c19 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Apr 2025 06:39:00 +0200 Subject: [PATCH 35/39] more merge fix --- .../classifier_free_guidance_plus_plus.py | 115 ------------------ 1 file changed, 115 deletions(-) delete mode 100644 src/diffusers/guiders/classifier_free_guidance_plus_plus.py diff --git a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py b/src/diffusers/guiders/classifier_free_guidance_plus_plus.py deleted file mode 100644 index d1c6f8744143..000000000000 --- a/src/diffusers/guiders/classifier_free_guidance_plus_plus.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Optional, Union, Tuple, List - -import torch - -from .guider_utils import BaseGuidance, rescale_noise_cfg, _default_prepare_inputs - - -class CFGPlusPlusGuidance(BaseGuidance): - """ - CFG++: https://huggingface.co/papers/2406.08070 - - Args: - guidance_scale (`float`, defaults to `0.7`): - The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text - prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and - deterioration of image quality. - guidance_rescale (`float`, defaults to `0.0`): - The rescale factor applied to the noise predictions. This is used to improve image quality and fix - overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://huggingface.co/papers/2305.08891). - use_original_formulation (`bool`, defaults to `False`): - Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, - we use the diffusers-native implementation that has been in the codebase for a long time. See - [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. - start (`float`, defaults to `0.0`): - The fraction of the total number of denoising steps after which guidance starts. - stop (`float`, defaults to `1.0`): - The fraction of the total number of denoising steps after which guidance stops. - """ - - _input_predictions = ["pred_cond", "pred_uncond"] - - def __init__( - self, - guidance_scale: float = 0.7, - guidance_rescale: float = 0.0, - use_original_formulation: bool = False, - start: float = 0.0, - stop: float = 1.0, - ): - super().__init__(start, stop) - - self.guidance_scale = guidance_scale - self.guidance_rescale = guidance_rescale - self.use_original_formulation = use_original_formulation - - def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: - return _default_prepare_inputs(denoiser, self.num_conditions, *args) - - def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None: - self._num_outputs_prepared += 1 - if self._num_outputs_prepared > self.num_conditions: - raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") - key = self._input_predictions[self._num_outputs_prepared - 1] - self._preds[key] = pred - - def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: - pred = None - - if not self._is_cfgpp_enabled(): - pred = pred_cond - else: - shift = pred_cond - pred_uncond - pred = pred_cond if self.use_original_formulation else pred_uncond - pred = pred + self.guidance_scale * shift - - if self.guidance_rescale > 0.0: - pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) - - return pred - - @property - def is_conditional(self) -> bool: - return self._num_outputs_prepared == 0 - - @property - def num_conditions(self) -> int: - num_conditions = 1 - if self._is_cfgpp_enabled(): - num_conditions += 1 - return num_conditions - - @property - def outputs(self) -> Dict[str, torch.Tensor]: - scheduler_step_kwargs = {} - if self._is_cfgpp_enabled(): - scheduler_step_kwargs["_use_cfgpp"] = True - scheduler_step_kwargs["_model_output_uncond"] = self._preds.get("pred_uncond") - return self._preds, scheduler_step_kwargs - - def _is_cfgpp_enabled(self) -> bool: - if not self._enabled: - return False - - is_within_range = True - if self._num_inference_steps is not None: - skip_start_step = int(self._start * self._num_inference_steps) - skip_stop_step = int(self._stop * self._num_inference_steps) - is_within_range = skip_start_step <= self._step < skip_stop_step - - return is_within_range From 9bfddfe65d9c69bfca55595646adfcea8e00b3b0 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 29 Apr 2025 18:43:46 -1000 Subject: [PATCH 36/39] Apply suggestions from code review --- src/diffusers/schedulers/scheduling_ddim.py | 8 ---- .../schedulers/scheduling_euler_discrete.py | 39 ------------------- 2 files changed, 47 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 2e74c9bbfccd..5c3cc6ed7a16 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -349,8 +349,6 @@ def step( generator=None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, - _model_output_uncond: Optional[torch.Tensor] = None, - _use_cfgpp: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -388,11 +386,6 @@ def step( raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) - - if _use_cfgpp and self.config.prediction_type != "epsilon": - raise ValueError( - f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}." - ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding @@ -418,7 +411,6 @@ def step( # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - pred_epsilon = model_output if not _use_cfgpp else _model_output_uncond elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 4c82ca7e389b..f4a4701541c4 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -584,8 +584,6 @@ def step( s_noise: float = 1.0, generator: Optional[torch.Generator] = None, return_dict: bool = True, - _model_output_uncond: Optional[torch.Tensor] = None, - _use_cfgpp: bool = False, ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion @@ -629,11 +627,6 @@ def step( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) - - if _use_cfgpp and self.config.prediction_type != "epsilon": - raise ValueError( - f"CFG++ is only supported for prediction type `epsilon`, but got {self.config.prediction_type}." - ) if self.step_index is None: self._init_step_index(timestep) @@ -675,38 +668,6 @@ def step( dt = self.sigmas[self.step_index + 1] - sigma_hat prev_sample = sample + derivative * dt - if _use_cfgpp: - prev_sample = prev_sample + (_model_output_uncond - model_output) * self.sigmas[self.step_index + 1] - - # denoised = sample - model_output * sigmas[i] - # d = (sample - denoised) / sigmas[i] - # new_sample = denoised + d * sigmas[i + 1] - - # new_sample = denoised + (sample - denoised) * sigmas[i + 1] / sigmas[i] - # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] - # new_sample = sample + model_output * (sigmas[i + 1] - sigmas[i]) - # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] --- (1) - - # CFG++ ===== - # denoised = sample - model_output * sigmas[i] - # uncond_denoised = sample - model_output_uncond * sigmas[i] - # d = (sample - uncond_denoised) / sigmas[i] - # new_sample = denoised + d * sigmas[i + 1] - - # new_sample = denoised + (sample - uncond_denoised) * sigmas[i + 1] / sigmas[i] - # new_sample = sample - model_output * sigmas[i] + model_output_uncond * sigmas[i + 1] --- (2) - - # To go from (1) to (2): - # new_sample_2 = new_sample_1 - model_output * sigmas[i + 1] + model_output_uncond * sigmas[i + 1] - # new_sample_2 = new_sample_1 + (model_output_uncond - model_output) * sigmas[i + 1] - # new_sample_2 = new_sample_1 + diff * sigmas[i + 1] - - # diff = model_output_uncond - model_output - # diff = model_output_uncond - (model_output_uncond + g * (model_output_cond - model_output_uncond)) - # diff = model_output_uncond - (g * model_output_cond + (1 - g) * model_output_uncond) - # diff = model_output_uncond - g * model_output_cond + (g - 1) * model_output_uncond - # diff = g * (model_output_uncond - model_output_cond) - # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) From cdb31dfccd5c876baa6fffcbd18db331b462b8e6 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Apr 2025 06:47:51 +0200 Subject: [PATCH 37/39] more merge fix --- src/diffusers/schedulers/scheduling_ddim.py | 1 + src/diffusers/schedulers/scheduling_euler_discrete.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 5c3cc6ed7a16..13c9b3b4a5e9 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -411,6 +411,7 @@ def step( # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index f4a4701541c4..56757f3ca197 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -668,6 +668,7 @@ def step( dt = self.sigmas[self.step_index + 1] - sigma_hat prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) From 45ca4309b12b7cf3efdafa0fe6086e6ffb09e6fc Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Apr 2025 06:56:56 +0200 Subject: [PATCH 38/39] ModularPipeline -> ModularLoader --- src/diffusers/__init__.py | 4 ++-- src/diffusers/pipelines/__init__.py | 4 ++-- src/diffusers/pipelines/pipeline_loading_utils.py | 2 +- .../pipeline_stable_diffusion_xl_modular.py | 2 +- src/diffusers/utils/dummy_pt_objects.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 11c43bab28cc..c9ee38ac6fda 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -249,7 +249,7 @@ "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", - "ModularPipeline", + "ModularLoader", "PNDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", @@ -840,7 +840,7 @@ KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, - ModularPipeline, + ModularLoader, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e0b47681af6f..7b6bd2071ef4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -46,7 +46,7 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] - _import_structure["modular_pipeline"] = ["ModularPipeline"] + _import_structure["modular_pipeline"] = ["ModularLoader"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] @@ -468,7 +468,7 @@ from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline - from .modular_pipeline import ModularPipeline + from .modular_pipeline import ModularLoader from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 46ec1c0d4344..48d5992f31ee 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -428,7 +428,7 @@ def _get_pipeline_class( revision=revision, ) - if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline": + if class_obj.__name__ != "DiffusionPipeline": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 94e01208b23e..5ae9e63851db 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -3569,7 +3569,7 @@ def description(self): # YiYi Notes: model specific components: -## (1) it should inherit from ModularPipelineComponents +## (1) it should inherit from ModularLoader ## (2) acts like a container that holds components and configs ## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents ## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index bea14cfe9c8d..f3837e39f192 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1328,7 +1328,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ModularPipeline(metaclass=DummyObject): +class ModularLoader(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From 35fa520625b9fa6151541fb9ab674eba46ce2ca5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Apr 2025 23:07:47 +0200 Subject: [PATCH 39/39] up --- src/diffusers/pipelines/modular_pipeline.py | 2 +- src/diffusers/pipelines/pipeline_utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 1f1784b186ab..636b543395df 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -172,7 +172,7 @@ def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, c # Import components loader (it is model-specific class) loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] - diffusers_module = importlib.import_module(self.__module__.split(".")[0]) + diffusers_module = importlib.import_module("diffusers") loader_class = getattr(diffusers_module, loader_class_name) # Create deep copies to avoid modifying the original specs diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c27cd434cd9a..22b0baee2e39 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1917,9 +1917,10 @@ def from_pipe(cls, pipeline, **kwargs): f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs } + optional_components = pipeline._optional_components if hasattr(pipeline, "_optional_components") and pipeline._optional_components else [] missing_modules = ( set(expected_modules) - - set(pipeline._optional_components) + - set(optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules) )