diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index ae8768ae9f72..7a3de0b95747 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -34,10 +34,12 @@
 _import_structure = {
     "configuration_utils": ["ConfigMixin"],
+    "guiders": [],
     "hooks": [],
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
+    "modular_pipelines": [],
     "quantizers.quantization_config": [],
     "schedulers": [],
     "utils": [
@@ -130,12 +132,26 @@
     _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
 else:
+    _import_structure["guiders"].extend(
+        [
+            "AdaptiveProjectedGuidance",
+            "AutoGuidance",
+            "ClassifierFreeGuidance",
+            "ClassifierFreeZeroStarGuidance",
+            "SkipLayerGuidance",
+            "SmoothedEnergyGuidance",
+            "TangentialClassifierFreeGuidance",
+        ]
+    )
     _import_structure["hooks"].extend(
         [
             "FasterCacheConfig",
             "HookRegistry",
             "PyramidAttentionBroadcastConfig",
+            "LayerSkipConfig",
+            "SmoothedEnergyGuidanceConfig",
             "apply_faster_cache",
+            "apply_layer_skip",
             "apply_pyramid_attention_broadcast",
         ]
     )
@@ -239,13 +255,19 @@
             "KarrasVePipeline",
             "LDMPipeline",
             "LDMSuperResolutionPipeline",
-            "ModularPipeline",
             "PNDMPipeline",
             "RePaintPipeline",
             "ScoreSdeVePipeline",
             "StableDiffusionMixin",
         ]
     )
+    _import_structure["modular_pipelines"].extend(
+        [
+            "ModularLoader",
+            "ComponentSpec",
+            "ComponentsManager",
+        ]
+    )
     _import_structure["quantizers"] = ["DiffusersQuantizer"]
     _import_structure["schedulers"].extend(
         [
@@ -494,12 +516,10 @@
         "StableDiffusionXLImg2ImgPipeline",
         "StableDiffusionXLInpaintPipeline",
         "StableDiffusionXLInstructPix2PixPipeline",
-        "StableDiffusionXLModularPipeline",
         "StableDiffusionXLPAGImg2ImgPipeline",
         "StableDiffusionXLPAGInpaintPipeline",
         "StableDiffusionXLPAGPipeline",
         "StableDiffusionXLPipeline",
-        "StableDiffusionXLAutoPipeline",
         "StableUnCLIPImg2ImgPipeline",
         "StableUnCLIPPipeline",
         "StableVideoDiffusionPipeline",
@@ -526,6 +546,24 @@
         ]
     )
+
+try:
+    if not (is_torch_available() and is_transformers_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _import_structure["utils.dummy_torch_and_transformers_objects"] = [
+        name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
+    ]
+
+else:
+    _import_structure["modular_pipelines"].extend(
+        [
+            "StableDiffusionXLAutoPipeline",
+            "StableDiffusionXLModularLoader",
+        ]
+    )
 try:
     if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
         raise OptionalDependencyNotAvailable()
@@ -731,10 +769,22 @@
     except OptionalDependencyNotAvailable:
         from .utils.dummy_pt_objects import *  # noqa F403
     else:
+        from .guiders import (
+            AdaptiveProjectedGuidance,
+            AutoGuidance,
+            ClassifierFreeGuidance,
+            ClassifierFreeZeroStarGuidance,
+            SkipLayerGuidance,
+            SmoothedEnergyGuidance,
+            TangentialClassifierFreeGuidance,
+        )
        from .hooks import (
             FasterCacheConfig,
             HookRegistry,
+            LayerSkipConfig,
             PyramidAttentionBroadcastConfig,
+            SmoothedEnergyGuidanceConfig,
+            apply_layer_skip,
             apply_faster_cache,
             apply_pyramid_attention_broadcast,
         )
@@ -837,12 +887,16 @@
             KarrasVePipeline,
             LDMPipeline,
             LDMSuperResolutionPipeline,
-            ModularPipeline,
             PNDMPipeline,
             RePaintPipeline,
             ScoreSdeVePipeline,
             StableDiffusionMixin,
         )
+        from .modular_pipelines import (
+            ModularLoader,
+            ComponentSpec,
+            ComponentsManager,
+        )
         from .quantizers import DiffusersQuantizer
         from .schedulers import (
             AmusedScheduler,
@@ -1070,12 +1124,10 @@
             StableDiffusionXLImg2ImgPipeline,
             StableDiffusionXLInpaintPipeline,
             StableDiffusionXLInstructPix2PixPipeline,
-            StableDiffusionXLModularPipeline,
             StableDiffusionXLPAGImg2ImgPipeline,
             StableDiffusionXLPAGInpaintPipeline,
             StableDiffusionXLPAGPipeline,
             StableDiffusionXLPipeline,
-            StableDiffusionXLAutoPipeline,
             StableUnCLIPImg2ImgPipeline,
             StableUnCLIPPipeline,
             StableVideoDiffusionPipeline,
@@ -1100,7 +1152,16 @@
             WuerstchenDecoderPipeline,
             WuerstchenPriorPipeline,
         )
-
+        try:
+            if not (is_torch_available() and is_transformers_available()):
+                raise OptionalDependencyNotAvailable()
+        except OptionalDependencyNotAvailable:
+            from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
+        else:
+            from .modular_pipelines import (
+                StableDiffusionXLAutoPipeline,
+                StableDiffusionXLModularLoader,
+            )
 try:
     if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
         raise OptionalDependencyNotAvailable()
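
Reviewer note: per the `_import_structure` entries and the eager imports added above, the new guider classes become part of the top-level public API. A minimal usage sketch, assuming the lazy-import machinery mirrors these tables as usual (the momentum value below is illustrative, not a default from this PR):

```python
# Sketch: importing the guiders this PR adds to the top-level diffusers namespace.
from diffusers import AdaptiveProjectedGuidance, ClassifierFreeGuidance

# Plain CFG with the diffusers-native formulation (see classifier_free_guidance.py below).
cfg = ClassifierFreeGuidance(guidance_scale=7.5, guidance_rescale=0.0)

# APG with a momentum buffer enabled; None (the default) disables momentum.
apg = AdaptiveProjectedGuidance(guidance_scale=7.5, adaptive_projected_guidance_momentum=-0.5)
```
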
- """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -class CFGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size :] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 2 - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. 
It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_classifier_free_guidance: - return model_output - - noise_pred_uncond, noise_pred_text = model_output.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred - - -class PAGGuider: - """ - This class is used to guide the pipeline with CFG (Classifier-Free Guidance). - """ - - def __init__( - self, - pag_applied_layers: Union[str, List[str]], - pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( - PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0(), - ), - ): - r""" - Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. - - Args: - pag_applied_layers (`str` or `List[str]`): - One or more strings identifying the layer names, or a simple regex for matching multiple layers, where - PAG is to be applied. A few ways of expected usage are as follows: - - Single layers specified as - "blocks.{layer_index}" - - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] 
- - Multiple layers as a block name - "mid" - - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" - pag_attn_processors: - (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), - PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention - processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second - attention processor is for PAG with CFG disabled (unconditional only). - """ - - if not isinstance(pag_applied_layers, list): - pag_applied_layers = [pag_applied_layers] - if pag_attn_processors is not None: - if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: - raise ValueError("Expected a tuple of two attention processors") - - for i in range(len(pag_applied_layers)): - if not isinstance(pag_applied_layers[i], str): - raise ValueError( - f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" - ) - - self.pag_applied_layers = pag_applied_layers - self._pag_attn_processors = pag_attn_processors - - def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance): - r""" - Set the attention processor for the PAG layers. - """ - pag_attn_processors = self._pag_attn_processors - pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] - - def is_self_attn(module: nn.Module) -> bool: - r""" - Check if the module is self-attention module based on its name. - """ - return isinstance(module, Attention) and not module.is_cross_attention - - def is_fake_integral_match(layer_id, name): - layer_id = layer_id.split(".")[-1] - name = name.split(".")[-1] - return layer_id.isnumeric() and name.isnumeric() and layer_id == name - - for layer_id in pag_applied_layers: - # for each PAG layer input, we find corresponding self-attention layers in the unet model - target_modules = [] - - for name, module in model.named_modules(): - # Identify the following simple cases: - # (1) Self Attention layer existing - # (2) Whether the module name matches pag layer id even partially - # (3) Make sure it's not a fake integral match if the layer_id ends with a number - # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" - if ( - is_self_attn(module) - and re.search(layer_id, name) is not None - and not is_fake_integral_match(layer_id, name) - ): - logger.debug(f"Applying PAG to layer: {name}") - target_modules.append(module) - - if len(target_modules) == 0: - raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") - - for module in target_modules: - module.processor = pag_attn_proc - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and not self._disable_guidance - - @property - def do_perturbed_attention_guidance(self): - return self._pag_scale > 0 and not self._disable_guidance - - @property - def do_pag_adaptive_scaling(self): - return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def batch_size(self): - return self._batch_size - - @property - def pag_scale(self): - return self._pag_scale - - @property - def pag_adaptive_scale(self): - return self._pag_adaptive_scale - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - pag_scale = 
guider_kwargs.get("pag_scale", 3.0) - pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0) - - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is a required argument for PAGGuider") - - guidance_scale = guider_kwargs.get("guidance_scale", None) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - disable_guidance = guider_kwargs.get("disable_guidance", False) - - if guidance_scale is None: - raise ValueError("guidance_scale is a required argument for PAGGuider") - - self._pag_scale = pag_scale - self._pag_adaptive_scale = pag_adaptive_scale - self._guidance_scale = guidance_scale - self._disable_guidance = disable_guidance - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None: - pipeline.original_attn_proc = pipeline.unet.attn_processors - self._set_pag_attn_processor( - model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer, - pag_applied_layers=self.pag_applied_layers, - do_classifier_free_guidance=self.do_classifier_free_guidance, - ) - - def reset_guider(self, pipeline): - if ( - self.do_perturbed_attention_guidance - and hasattr(pipeline, "original_attn_proc") - and pipeline.original_attn_proc is not None - ): - pipeline.unet.set_attn_processor(pipeline.original_attn_proc) - pipeline.original_attn_proc = None - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Perturbed Attention Guidance (PAG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 3 - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). - It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 3: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size : self.batch_size * 2] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]): - The negative conditional input. 
It can be a single tensor or a list of tensors. It must have the same - length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. - - if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_perturbed_attention_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - - cond = torch.cat([cond] * 2, dim=0) - if self.do_classifier_free_guidance: - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - else: - prepared_input.append(cond) - - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_perturbed_attention_guidance: - return cond_input - - cond_input = torch.cat([cond_input] * 2, dim=0) - if self.do_classifier_free_guidance: - return torch.cat([negative_cond_input, cond_input], dim=0) - else: - return cond_input - - else: - raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_perturbed_attention_guidance: - return model_output - - if self.do_pag_adaptive_scaling: - pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0) - else: - pag_scale = self._pag_scale - - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3) - noise_pred = ( - noise_pred_uncond - + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - + pag_scale * (noise_pred_text - noise_pred_perturb) - ) - else: - noise_pred_text, noise_pred_perturb = model_output.chunk(2) - noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - - return noise_pred - - -class MomentumBuffer: - def __init__(self, momentum: float): - self.momentum = momentum - self.running_average = 0 - - def update(self, update_value: torch.Tensor): - new_average = self.momentum * self.running_average - self.running_average = update_value + new_average - - -class APGGuider: - """ - This class is used to guide the pipeline with APG (Adaptive Projected Guidance). 
- """ - - def normalized_guidance( - self, - pred_cond: torch.Tensor, - pred_uncond: torch.Tensor, - guidance_scale: float, - momentum_buffer: MomentumBuffer = None, - norm_threshold: float = 0.0, - eta: float = 1.0, - ): - """ - Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion - Models](https://arxiv.org/pdf/2410.02416) - """ - diff = pred_cond - pred_uncond - if momentum_buffer is not None: - momentum_buffer.update(diff) - diff = momentum_buffer.running_average - if norm_threshold > 0: - ones = torch.ones_like(diff) - diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True) - scale_factor = torch.minimum(ones, norm_threshold / diff_norm) - diff = diff * scale_factor - v0, v1 = diff.double(), pred_cond.double() - v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) - v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 - v0_orthogonal = v0 - v0_parallel - diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype) - normalized_update = diff_orthogonal + eta * diff_parallel - pred_guided = pred_cond + (guidance_scale - 1) * normalized_update - return pred_guided - - @property - def adaptive_projected_guidance_momentum(self): - return self._adaptive_projected_guidance_momentum - - @property - def adaptive_projected_guidance_rescale_factor(self): - return self._adaptive_projected_guidance_rescale_factor - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 and not self._disable_guidance - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def batch_size(self): - return self._batch_size - - def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]): - disable_guidance = guider_kwargs.get("disable_guidance", False) - guidance_scale = guider_kwargs.get("guidance_scale", None) - if guidance_scale is None: - raise ValueError("guidance_scale is not provided in guider_kwargs") - adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None) - adaptive_projected_guidance_rescale_factor = guider_kwargs.get( - "adaptive_projected_guidance_rescale_factor", 15.0 - ) - guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0) - batch_size = guider_kwargs.get("batch_size", None) - if batch_size is None: - raise ValueError("batch_size is not provided in guider_kwargs") - self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum - self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._batch_size = batch_size - self._disable_guidance = disable_guidance - if adaptive_projected_guidance_momentum is not None: - self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum) - else: - self.momentum_buffer = None - self.scheduler = pipeline.scheduler - - def reset_guider(self, pipeline): - pass - - def maybe_update_guider(self, pipeline, timestep): - pass - - def maybe_update_input(self, pipeline, cond_input): - pass - - def _maybe_split_prepared_input(self, cond): - """ - Process and potentially split the conditional input for Classifier-Free Guidance (CFG). - - This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`). 
- It determines whether to split the input based on its batch size relative to the expected batch size. - - Args: - cond (torch.Tensor): The conditional input tensor to process. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The negative conditional input (uncond_input) - - The positive conditional input (cond_input) - """ - if cond.shape[0] == self.batch_size * 2: - neg_cond = cond[0 : self.batch_size] - cond = cond[self.batch_size :] - return neg_cond, cond - elif cond.shape[0] == self.batch_size: - return cond, cond - else: - raise ValueError(f"Unsupported input shape: {cond.shape}") - - def _is_prepared_input(self, cond): - """ - Check if the input is already prepared for Classifier-Free Guidance (CFG). - - Args: - cond (torch.Tensor): The conditional input tensor to check. - - Returns: - bool: True if the input is already prepared, False otherwise. - """ - cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond - - return cond_tensor.shape[0] == self.batch_size * 2 - - def prepare_input( - self, - cond_input: Union[torch.Tensor, List[torch.Tensor]], - negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Prepare the input for CFG. - - Args: - cond_input (Union[torch.Tensor, List[torch.Tensor]]): - The conditional input. It can be a single tensor or a - list of tensors. It must have the same length as `negative_cond_input`. - negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a - single tensor or a list of tensors. It must have the same length as `cond_input`. - - Returns: - Union[torch.Tensor, List[torch.Tensor]]: The prepared input. - """ - - # we check if cond_input already has CFG applied, and split if it is the case. 
- if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance: - return cond_input - - if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance: - if isinstance(cond_input, list): - negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input]) - else: - negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input) - - if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None: - raise ValueError( - "`negative_cond_input` is required when cond_input does not already contains negative conditional input" - ) - - if isinstance(cond_input, (list, tuple)): - if not self.do_classifier_free_guidance: - return cond_input - - if len(negative_cond_input) != len(cond_input): - raise ValueError("The length of negative_cond_input and cond_input must be the same.") - prepared_input = [] - for neg_cond, cond in zip(negative_cond_input, cond_input): - if neg_cond.shape[0] != cond.shape[0]: - raise ValueError("The batch size of negative_cond_input and cond_input must be the same.") - prepared_input.append(torch.cat([neg_cond, cond], dim=0)) - return prepared_input - - elif isinstance(cond_input, torch.Tensor): - if not self.do_classifier_free_guidance: - return cond_input - else: - return torch.cat([negative_cond_input, cond_input], dim=0) - - else: - raise ValueError(f"Unsupported input type: {type(cond_input)}") - - def apply_guidance( - self, - model_output: torch.Tensor, - timestep: int = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not self.do_classifier_free_guidance: - return model_output - - if latents is None: - raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).") - - sigma = self.scheduler.sigmas[self.scheduler.step_index] - noise_pred = latents - sigma * model_output - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = self.normalized_guidance( - noise_pred_text, - noise_pred_uncond, - self.guidance_scale, - self.momentum_buffer, - self.adaptive_projected_guidance_rescale_factor, - ) - noise_pred = (latents - noise_pred) / sigma - - if self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - return noise_pred diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py new file mode 100644 index 000000000000..3c1ee293382d --- /dev/null +++ b/src/diffusers/guiders/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
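
Reviewer note: the arithmetic the deleted `CFGGuider.apply_guidance` performed (and which the new `BaseGuidance` subclasses reimplement) is compact enough to show end to end. A sketch on toy tensors, taken directly from the code above:

```python
# Sketch of the deleted CFGGuider's guidance + rescale math. Shapes are illustrative.
import torch

batch, channels, height, width = 2, 4, 64, 64
noise_pred_uncond = torch.randn(batch, channels, height, width)
noise_pred_text = torch.randn(batch, channels, height, width)
guidance_scale, guidance_rescale = 7.5, 0.7

# Classifier-free guidance: move the unconditional prediction toward the conditional one.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# rescale_noise_cfg: match the guided prediction's per-sample std to the conditional
# prediction's std, then blend by guidance_rescale (Sec. 3.4, arXiv:2305.08891).
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
rescaled = noise_pred * (std_text / std_cfg)
noise_pred = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_pred
```
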
diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py
new file mode 100644
index 000000000000..3c1ee293382d
--- /dev/null
+++ b/src/diffusers/guiders/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+from ..utils import is_torch_available
+
+
+if is_torch_available():
+    from .adaptive_projected_guidance import AdaptiveProjectedGuidance
+    from .auto_guidance import AutoGuidance
+    from .classifier_free_guidance import ClassifierFreeGuidance
+    from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
+    from .skip_layer_guidance import SkipLayerGuidance
+    from .smoothed_energy_guidance import SmoothedEnergyGuidance
+    from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
+
+    GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance]
diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py
new file mode 100644
index 000000000000..ef2f3f2c8420
--- /dev/null
+++ b/src/diffusers/guiders/adaptive_projected_guidance.py
@@ -0,0 +1,184 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
+
+import torch
+
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class AdaptiveProjectedGuidance(BaseGuidance):
+    """
+    Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
+
+    Args:
+        guidance_scale (`float`, defaults to `7.5`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
+            text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation
+            and deterioration of image quality.
+        adaptive_projected_guidance_momentum (`float`, defaults to `None`):
+            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
+        adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
+            The norm threshold for adaptive projected guidance. The guidance update is rescaled down whenever its
+            norm exceeds this value.
+        eta (`float`, defaults to `1.0`):
+            The weight applied to the parallel component of the projected guidance update.
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
+    """
+
+    _input_predictions = ["pred_cond", "pred_uncond"]
+
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        adaptive_projected_guidance_momentum: Optional[float] = None,
+        adaptive_projected_guidance_rescale: float = 15.0,
+        eta: float = 1.0,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+        self.eta = eta
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+        self.momentum_buffer = None
+
+    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
+        if input_fields is None:
+            input_fields = self._input_fields
+
+        if self._step == 0:
+            if self.adaptive_projected_guidance_momentum is not None:
+                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for i in range(self.num_conditions):
+            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+            data_batches.append(data_batch)
+        return data_batches
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pred = None
+
+        if not self._is_apg_enabled():
+            pred = pred_cond
+        else:
+            pred = normalized_guidance(
+                pred_cond,
+                pred_uncond,
+                self.guidance_scale,
+                self.momentum_buffer,
+                self.eta,
+                self.adaptive_projected_guidance_rescale,
+                self.use_original_formulation,
+            )
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred, {}
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._count_prepared == 1
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_apg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    def _is_apg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
+
+
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+def normalized_guidance(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer: Optional[MomentumBuffer] = None,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+    use_original_formulation: bool = False,
+):
+    diff = pred_cond - pred_uncond
+    dim = [-i for i in range(1, len(diff.shape))]
+
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=dim)
+    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+    normalized_update = diff_orthogonal + eta * diff_parallel
+
+    pred = pred_cond if use_original_formulation else pred_uncond
+    pred = pred + guidance_scale * normalized_update
+
+    return pred
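
Reviewer note: the projection in `normalized_guidance` is easier to follow on toy tensors. A sketch of the same computation, step by step:

```python
# Sketch: what normalized_guidance computes. The parallel/orthogonal split is taken
# with respect to the (L2-normalized) conditional prediction.
import torch

pred_cond = torch.randn(1, 4, 8, 8)
pred_uncond = torch.randn(1, 4, 8, 8)
diff = pred_cond - pred_uncond
dim = [-1, -2, -3]  # all non-batch dims, as derived from diff.shape in the function

v1 = torch.nn.functional.normalize(pred_cond.double(), dim=dim)
v0_parallel = (diff.double() * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = diff.double() - v0_parallel

eta = 1.0  # eta weights the parallel component; eta=0 keeps only the orthogonal part
update = (v0_orthogonal + eta * v0_parallel).type_as(diff)

# diffusers-native formulation (use_original_formulation=False):
pred = pred_uncond + 7.5 * update
```
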
diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py
new file mode 100644
index 000000000000..791cc582add2
--- /dev/null
+++ b/src/diffusers/guiders/auto_guidance.py
@@ -0,0 +1,177 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
+
+import torch
+
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class AutoGuidance(BaseGuidance):
+    """
+    AutoGuidance: https://huggingface.co/papers/2406.02507
+
+    Args:
+        guidance_scale (`float`, defaults to `7.5`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
+            text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation
+            and deterioration of image quality.
+        auto_guidance_layers (`int` or `List[int]`, *optional*):
+            The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
+            provided, `auto_guidance_config` must be provided.
+        auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
+            The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
+            `LayerSkipConfig`. If not provided, `auto_guidance_layers` must be provided.
+        dropout (`float`, *optional*):
+            The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
+            `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
+    """
+
+    _input_predictions = ["pred_cond", "pred_uncond"]
+
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        auto_guidance_layers: Optional[Union[int, List[int]]] = None,
+        auto_guidance_config: Optional[Union[LayerSkipConfig, List[LayerSkipConfig]]] = None,
+        dropout: Optional[float] = None,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.auto_guidance_layers = auto_guidance_layers
+        self.auto_guidance_config = auto_guidance_config
+        self.dropout = dropout
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+
+        if auto_guidance_layers is None and auto_guidance_config is None:
+            raise ValueError(
+                "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable AutoGuidance."
+            )
+        if auto_guidance_layers is not None and auto_guidance_config is not None:
+            raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
+        if (dropout is None and auto_guidance_layers is not None) or (dropout is not None and auto_guidance_layers is None):
+            raise ValueError("`dropout` must be provided if, and only if, `auto_guidance_layers` is provided.")
+
+        if auto_guidance_layers is not None:
+            if isinstance(auto_guidance_layers, int):
+                auto_guidance_layers = [auto_guidance_layers]
+            if not isinstance(auto_guidance_layers, list):
+                raise ValueError(
+                    f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
+                )
+            auto_guidance_config = [LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers]
+
+        if isinstance(auto_guidance_config, LayerSkipConfig):
+            auto_guidance_config = [auto_guidance_config]
+
+        if not isinstance(auto_guidance_config, list):
+            raise ValueError(
+                f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
+            )
+
+        self.auto_guidance_config = auto_guidance_config
+        self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
+
+    def prepare_models(self, denoiser: torch.nn.Module) -> None:
+        self._count_prepared += 1
+        if self._is_ag_enabled() and self.is_unconditional:
+            for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
+                _apply_layer_skip_hook(denoiser, config, name=name)
+
+    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+        if self._is_ag_enabled() and self.is_unconditional:
+            for name in self._auto_guidance_hook_names:
+                registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+                registry.remove_hook(name, recurse=True)
+
+    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
+        if input_fields is None:
+            input_fields = self._input_fields
+
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for i in range(self.num_conditions):
+            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+            data_batches.append(data_batch)
+        return data_batches
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pred = None
+
+        if not self._is_ag_enabled():
+            pred = pred_cond
+        else:
+            shift = pred_cond - pred_uncond
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred, {}
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._count_prepared == 1
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_ag_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    def _is_ag_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py
new file mode 100644
index 000000000000..a459e51cd083
--- /dev/null
+++ b/src/diffusers/guiders/classifier_free_guidance.py
@@ -0,0 +1,132 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
+
+import torch
+
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class ClassifierFreeGuidance(BaseGuidance):
+    """
+    Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
+
+    CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
+    jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
+    inference. This allows the model to tradeoff between generation quality and sample diversity.
+
+    The original paper proposes scaling and shifting the conditional distribution based on the difference between
+    conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
+
+    Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
+    paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
+    theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
+
+    The intuition behind the original formulation can be thought of as moving the conditional distribution estimates
+    further away from the unconditional distribution estimates, while the diffusers-native implementation can be
+    thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
+    the unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.)
+
+    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
+    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
+
+    Args:
+        guidance_scale (`float`, defaults to `7.5`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
+            text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation
+            and deterioration of image quality.
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
+    """
+
+    _input_predictions = ["pred_cond", "pred_uncond"]
+
+    def __init__(
+        self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0
+    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+
+    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
+        if input_fields is None:
+            input_fields = self._input_fields
+
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for i in range(self.num_conditions):
+            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+            data_batches.append(data_batch)
+        return data_batches
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pred = None
+
+        if not self._is_cfg_enabled():
+            pred = pred_cond
+        else:
+            shift = pred_cond - pred_uncond
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred, {}
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._count_prepared == 1
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_cfg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py
new file mode 100644
index 000000000000..a722f2605036
--- /dev/null
+++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py
@@ -0,0 +1,148 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
+
+import torch
+
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class ClassifierFreeZeroStarGuidance(BaseGuidance):
+    """
+    Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
+
+    This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
+    guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
+    process, and also introduces an optimal rescaling factor for the noise predictions, which can help improve the
+    quality of generated images.
+
+    The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
+
+    Args:
+        guidance_scale (`float`, defaults to `7.5`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
+            text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation
+            and deterioration of image quality.
+        zero_init_steps (`int`, defaults to `1`):
+            The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
+    """
+
+    _input_predictions = ["pred_cond", "pred_uncond"]
+
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        zero_init_steps: int = 1,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.zero_init_steps = zero_init_steps
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+
+    def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
+        if input_fields is None:
+            input_fields = self._input_fields
+
+        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+        data_batches = []
+        for i in range(self.num_conditions):
+            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+            data_batches.append(data_batch)
+        return data_batches
+
+    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pred = None
+
+        if self._step < self.zero_init_steps:
+            pred = torch.zeros_like(pred_cond)
+        elif not self._is_cfg_enabled():
+            pred = pred_cond
+        else:
+            pred_cond_flat = pred_cond.flatten(1)
+            pred_uncond_flat = pred_uncond.flatten(1)
+            alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
+            alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
+            pred_uncond = pred_uncond * alpha
+            shift = pred_cond - pred_uncond
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred, {}
+
+    @property
+    def is_conditional(self) -> bool:
+        return self._count_prepared == 1
+
+    @property
+    def num_conditions(self) -> int:
+        num_conditions = 1
+        if self._is_cfg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
+
+
+def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
+    cond_dtype = cond.dtype
+    cond = cond.float()
+    uncond = uncond.float()
+    dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
+    squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
+    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+    scale = dot_product / squared_norm
+    return scale.to(dtype=cond_dtype)
diff --git a/src/diffusers/guiders/entropy_rectifying_guidance.py b/src/diffusers/guiders/entropy_rectifying_guidance.py
new file mode 100644
index 000000000000..e69de29bb2d1
+ ) + + def disable(self): + self._enabled = False + + def enable(self): + self._enabled = True + + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + self._step = step + self._num_inference_steps = num_inference_steps + self._timestep = timestep + self._count_prepared = 0 + + def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None: + """ + Set the input fields for the guidance technique. The input fields are used to specify the names of the + returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is + obtained from the values of the provided keyword arguments to this method. + + Args: + **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once + it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, + which is used to look up the required data provided for preparation. + + If a string is provided, it will be used as the conditional data (or unconditional if used with + a guidance method that requires it). If a tuple of length 2 is provided, the first element must + be the conditional data identifier and the second element must be the unconditional data identifier + or None. + + Example: + + ``` + data = {"prompt_embeds": ..., "negative_prompt_embeds": ..., "latents": ...} + + BaseGuidance.set_input_fields( + latents="latents", + prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), + ) + ``` + """ + for key, value in kwargs.items(): + is_string = isinstance(value, str) + is_tuple_of_str_with_len_2 = isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value) + if not (is_string or is_tuple_of_str_with_len_2): + raise ValueError( + f"Expected `set_input_fields` to be called with a string or a tuple of strings of length 2, but got {type(value)} for key {key}." + ) + self._input_fields = kwargs + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + """ + Prepares the models for the guidance technique on a given batch of data. This method should be overridden in + subclasses to implement specific model preparation logic. + """ + self._count_prepared += 1 + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + """ + Cleans up the models for the guidance technique after a given batch of data. This method should be overridden + in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful + modifications made during `prepare_models`. + """ + pass + + def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") + + def __call__(self, data: List["BlockState"]) -> Any: + if not all(hasattr(d, "noise_pred") for d in data): + raise ValueError("Expected all data to have `noise_pred` attribute.") + if len(data) != self.num_conditions: + raise ValueError( + f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
+ ) + forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data} + return self.forward(**forward_inputs) + + def forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.") + + @property + def is_conditional(self) -> bool: + raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.") + + @property + def is_unconditional(self) -> bool: + return not self.is_conditional + + @property + def num_conditions(self) -> int: + raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.") + + @classmethod + def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState": + """ + Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of + the `BaseGuidance` class. It prepares the batch based on the provided tuple index. + + Args: + input_fields (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once + it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, + which is used to look up the required data provided for preparation. + If a string is provided, it will be used as the conditional data (or unconditional if used with + a guidance method that requires it). If a tuple of length 2 is provided, the first element must + be the conditional data identifier and the second element must be the unconditional data identifier + or None. + data (`BlockState`): + The input data to be prepared. + tuple_index (`int`): + The index to use when accessing input fields that are tuples. + + Returns: + `BlockState`: The prepared batch of data. + """ + from ..modular_pipelines.modular_pipeline import BlockState + + if input_fields is None: + raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.") + data_batch = {} + for key, value in input_fields.items(): + try: + if isinstance(value, str): + data_batch[key] = getattr(data, value) + elif isinstance(value, tuple): + data_batch[key] = getattr(data, value[tuple_index]) + else: + # We've already checked that value is a string or a tuple of strings with length 2 + pass + except AttributeError: + logger.debug(f"`data` does not have attribute(s) {value}, skipping.") + data_batch[cls._identifier_key] = identifier + return BlockState(**data_batch) + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. 
+ """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py new file mode 100644 index 000000000000..7c19f6391f41 --- /dev/null +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -0,0 +1,251 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class SkipLayerGuidance(BaseGuidance): + """ + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 + + Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664 + + SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by + skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional + batch of data, apart from the conditional and unconditional batches already used in CFG + ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions + based on the difference between conditional without skipping and conditional with skipping predictions. + + The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from + worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse + version of the model for the conditional prediction). + + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving + generation quality in video diffusion models. + + Additional reading: + - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are + defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. 
+ skip_layer_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + skip_layer_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which skip layer guidance starts. + skip_layer_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which skip layer guidance stops. + skip_layer_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + + def __init__( + self, + guidance_scale: float = 7.5, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_config: Optional[Union[LayerSkipConfig, List[LayerSkipConfig]]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.skip_layer_guidance_start = skip_layer_guidance_start + self.skip_layer_guidance_stop = skip_layer_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= skip_layer_guidance_start < 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}." + ) + if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_stop` to be between {skip_layer_guidance_start} and 1.0, but got {skip_layer_guidance_stop}." + ) + + if skip_layer_guidance_layers is None and skip_layer_config is None: + raise ValueError( + "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
+ ) + if skip_layer_guidance_layers is not None and skip_layer_config is not None: + raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.") + + if skip_layer_guidance_layers is not None: + if isinstance(skip_layer_guidance_layers, int): + skip_layer_guidance_layers = [skip_layer_guidance_layers] + if not isinstance(skip_layer_guidance_layers, list): + raise ValueError( + f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}." + ) + skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers] + + if isinstance(skip_layer_config, LayerSkipConfig): + skip_layer_config = [skip_layer_config] + + if not isinstance(skip_layer_config, list): + raise ValueError( + f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}." + ) + + self.skip_layer_config = skip_layer_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_slg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + elif not self._is_slg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_skip = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + 
def is_conditional(self) -> bool: + return self._count_prepared == 1 or self._count_prepared == 3 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_slg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_slg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py new file mode 100644 index 000000000000..3986da913f82 --- /dev/null +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -0,0 +1,244 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple + +import torch + +from ..hooks import HookRegistry +from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class SmoothedEnergyGuidance(BaseGuidance): + """ + Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760 + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. This implementation assumes: + - Generated images are square (height == width) + - The model does not combine different modalities together (e.g., text and image latent streams are + not combined, as they are in Flux) + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + seg_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for smoothed energy guidance.
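As a usage sketch, constructing this guider could look like the following; it assumes the top-level `SmoothedEnergyGuidance` export added elsewhere in this PR, and the layer indices are illustrative rather than recommended values:

```python
from diffusers import SmoothedEnergyGuidance

# Hypothetical construction: blur the attention queries of blocks 7-9 while
# applying standard CFG with scale 7.5 on top.
guider = SmoothedEnergyGuidance(
    guidance_scale=7.5,
    seg_guidance_scale=2.8,
    seg_guidance_layers=[7, 8, 9],
)
```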
Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + seg_blur_sigma (`float`, defaults to `9999999.0`): + The sigma of the Gaussian blur applied to the attention weights. Setting this value greater than 9999.0 + results in infinite blur, which means uniform queries. Controlling it exponentially is empirically + effective. + seg_blur_threshold_inf (`float`, defaults to `9999.0`): + The threshold above which the blur is considered infinite. + seg_guidance_start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance starts. + seg_guidance_stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which smoothed energy guidance stops. + seg_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If + not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable + Diffusion 3.5 Medium. + seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*): + The configuration for smoothed energy guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list + of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + + def __init__( + self, + guidance_scale: float = 7.5, + seg_guidance_scale: float = 2.8, + seg_blur_sigma: float = 9999999.0, + seg_blur_threshold_inf: float = 9999.0, + seg_guidance_start: float = 0.0, + seg_guidance_stop: float = 1.0, + seg_guidance_layers: Optional[Union[int, List[int]]] = None, + seg_guidance_config: Optional[Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.seg_guidance_scale = seg_guidance_scale + self.seg_blur_sigma = seg_blur_sigma + self.seg_blur_threshold_inf = seg_blur_threshold_inf + self.seg_guidance_start = seg_guidance_start + self.seg_guidance_stop = seg_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= seg_guidance_start < 1.0): + raise ValueError( + f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}."
+ ) + if not (seg_guidance_start <= seg_guidance_stop <= 1.0): + raise ValueError( + f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}." + ) + + if seg_guidance_layers is None and seg_guidance_config is None: + raise ValueError( + "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance." + ) + if seg_guidance_layers is not None and seg_guidance_config is not None: + raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.") + + if seg_guidance_layers is not None: + if isinstance(seg_guidance_layers, int): + seg_guidance_layers = [seg_guidance_layers] + if not isinstance(seg_guidance_layers, list): + raise ValueError( + f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}." + ) + seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers] + + if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig): + seg_guidance_config = [seg_guidance_config] + + if not isinstance(seg_guidance_config, list): + raise ValueError( + f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}." + ) + + self.seg_guidance_config = seg_guidance_config + self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))] + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: + for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): + _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) + + def cleanup_models(self, denoiser: torch.nn.Module): + if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._seg_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"] + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_seg: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_seg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_cond_seg + pred = pred + self.seg_guidance_scale * shift + elif not self._is_seg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + 
shift = pred_cond - pred_uncond + shift_seg = pred_cond - pred_cond_seg + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 or self._count_prepared == 3 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_seg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + def _is_seg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self.seg_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + + is_zero = math.isclose(self.seg_guidance_scale, 0.0) + + return is_within_range and not is_zero diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py new file mode 100644 index 000000000000..017693fd9f07 --- /dev/null +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -0,0 +1,137 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple + +import torch + +from .guider_utils import BaseGuidance, rescale_noise_cfg + +if TYPE_CHECKING: + from ..modular_pipelines.modular_pipeline import BlockState + + +class TangentialClassifierFreeGuidance(BaseGuidance): + """ + Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. 
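A condensed sketch of the tangential projection that `normalized_guidance` below performs may help build intuition; the shapes are toy values and the variable names are illustrative:

```python
import torch

pred_cond = torch.randn(2, 4, 8, 8)
pred_uncond = torch.randn(2, 4, 8, 8)

# Stack both predictions and take the SVD over the flattened feature axis.
preds = torch.stack([pred_cond, pred_uncond], dim=1).flatten(2)  # (B, 2, N)
_, _, Vh = torch.linalg.svd(preds, full_matrices=False)  # Vh: (B, 2, N)
Vh_modified = Vh.clone()
Vh_modified[:, 1] = 0  # drop the second right-singular direction

# Re-express the unconditional prediction using only the retained direction,
# removing its component tangential to the conditional prediction.
uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1)
projected = uncond_flat @ Vh.transpose(-2, -1) @ Vh_modified
pred_uncond_tangential = projected.reshape(pred_uncond.shape)
```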
Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details. + start (`float`, defaults to `0.0`): + The fraction of the total number of denoising steps after which guidance starts. + stop (`float`, defaults to `1.0`): + The fraction of the total number of denoising steps after which guidance stops. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + ): + super().__init__(start, stop) + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for i in range(self.num_conditions): + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_tcfg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred, {} + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_tcfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_tcfg_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor: + cond_dtype = pred_cond.dtype + preds = torch.stack([pred_cond, pred_uncond], dim=1).float() + preds = preds.flatten(2) + U, S, Vh = torch.linalg.svd(preds, full_matrices=False) + Vh_modified = Vh.clone() + Vh_modified[:, 1] = 0 + + uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float() + x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1)) + x_Vh_V = torch.matmul(x_Vh, Vh_modified) + pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype) + + pred = pred_cond if use_original_formulation else pred_uncond + shift = pred_cond - 
pred_uncond + pred = pred + guidance_scale * shift + + return pred diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 764ceb25b465..9d0e96e9e79e 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -5,5 +5,7 @@ from .faster_cache import FasterCacheConfig, apply_faster_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook + from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py new file mode 100644 index 000000000000..3d9c99e8189f --- /dev/null +++ b/src/diffusers/hooks/_common.py @@ -0,0 +1,43 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + +from ..models.attention import FeedForward, LuminaFeedForward +from ..models.attention_processor import Attention, MochiAttention + + +_ATTENTION_CLASSES = (Attention, MochiAttention) +_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) + +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") + +_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple( + { + *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + } +) + + +def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]: + for submodule_name, submodule in module.named_modules(): + if submodule_name == fqn: + return submodule + return None diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py new file mode 100644 index 000000000000..9043ffc41838 --- /dev/null +++ b/src/diffusers/hooks/_helpers.py @@ -0,0 +1,271 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
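The registries defined in this helper file map attention-processor and transformer-block classes to the metadata the skip hooks need; a hypothetical lookup (this is a private module, so the import path is subject to change) might read:

```python
from diffusers.hooks._helpers import TransformerBlockRegistry
from diffusers.models.attention import BasicTransformerBlock

# Fetch the registered metadata for a known block class.
metadata = TransformerBlockRegistry.get(BasicTransformerBlock)
print(metadata.return_hidden_states_index)  # 0: the block returns hidden_states first
```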
+ +from dataclasses import dataclass +from typing import Any, Callable, Type + +from ..models.attention import BasicTransformerBlock +from ..models.attention_processor import AttnProcessor2_0 +from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock +from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock +from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ..models.transformers.transformer_hunyuan_video import ( + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenReplaceSingleTransformerBlock, + HunyuanVideoTokenReplaceTransformerBlock, + HunyuanVideoTransformerBlock, +) +from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock +from ..models.transformers.transformer_mochi import MochiTransformerBlock +from ..models.transformers.transformer_wan import WanTransformerBlock + + +@dataclass +class AttentionProcessorMetadata: + skip_processor_output_fn: Callable[[Any], Any] + + +@dataclass +class TransformerBlockMetadata: + skip_block_output_fn: Callable[[Any], Any] + return_hidden_states_index: int = None + return_encoder_hidden_states_index: int = None + + +class AttentionProcessorRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> AttentionProcessorMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +class TransformerBlockRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +def _register_attention_processors_metadata(): + # AttnProcessor2_0 + AttentionProcessorRegistry.register( + model_class=AttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0, + ), + ) + + # CogView4AttnProcessor + AttentionProcessorRegistry.register( + model_class=CogView4AttnProcessor, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor, + ), + ) + + +def _register_transformer_blocks_metadata(): + # BasicTransformerBlock + TransformerBlockRegistry.register( + model_class=BasicTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # CogVideoX + TransformerBlockRegistry.register( + model_class=CogVideoXBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # CogView4 + TransformerBlockRegistry.register( + model_class=CogView4TransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Flux + TransformerBlockRegistry.register( + model_class=FluxTransformerBlock, + metadata=TransformerBlockMetadata( + 
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + TransformerBlockRegistry.register( + model_class=FluxSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + # HunyuanVideo + TransformerBlockRegistry.register( + model_class=HunyuanVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # LTXVideo + TransformerBlockRegistry.register( + model_class=LTXVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # Mochi + TransformerBlockRegistry.register( + model_class=MochiTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Wan + TransformerBlockRegistry.register( + model_class=WanTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + +# fmt: off +def _skip_attention___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states +_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def 
_skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return encoder_hidden_states, hidden_states + + +_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +# fmt: on + + +_register_attention_processors_metadata() +_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py new file mode 100644 index 000000000000..65a99464ba2f --- /dev/null +++ b/src/diffusers/hooks/layer_skip.py @@ -0,0 +1,231 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Callable, List, Optional + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn +from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_LAYER_SKIP_HOOK = "layer_skip_hook" + + +# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed +# either remove or make it serializable +@dataclass +class LayerSkipConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + For automatic detection, set this to `"auto"`. + "auto" only works on DiT models. For UNet models, you must provide the correct fqn. + skip_attention (`bool`, defaults to `True`): + Whether to skip attention blocks. + skip_ff (`bool`, defaults to `True`): + Whether to skip feed-forward blocks. + skip_attention_scores (`bool`, defaults to `False`): + Whether to skip attention score computation in the attention blocks. This is equivalent to using `value` + projections as the output of scaled dot product attention. + dropout (`float`, defaults to `1.0`): + The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`, + meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the + skipped layers are fully retained, which is equivalent to not skipping any layers. + """ + + indices: List[int] + fqn: str = "auto" + skip_attention: bool = True + skip_attention_scores: bool = False + skip_ff: bool = True + dropout: float = 1.0 + + def __post_init__(self): + if not (0 <= self.dropout <= 1): + raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.") + if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores: + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." 
+ ) + + +class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.scaled_dot_product_attention: + value = kwargs.get("value", None) + if value is None: + value = args[2] + return value + return func(*args, **kwargs) + + +class AttentionProcessorSkipHook(ModelHook): + def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0): + self.skip_processor_output_fn = skip_processor_output_fn + self.skip_attention_scores = skip_attention_scores + self.dropout = dropout + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.skip_attention_scores: + if not math.isclose(self.dropout, 1.0): + raise ValueError( + "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." + ) + with AttentionScoreSkipFunctionMode(): + output = self.fn_ref.original_forward(*args, **kwargs) + else: + if math.isclose(self.dropout, 1.0): + output = self.skip_processor_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + + +class FeedForwardSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if math.isclose(self.dropout, 1.0): + output = kwargs.get("hidden_states", None) + if output is None: + output = kwargs.get("x", None) + if output is None and len(args) > 0: + output = args[0] + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + + +class TransformerBlockSkipHook(ModelHook): + def __init__(self, dropout: float): + super().__init__() + self.dropout = dropout + + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if math.isclose(self.dropout, 1.0): + output = self._metadata.skip_block_output_fn(module, *args, **kwargs) + else: + output = self.fn_ref.original_forward(*args, **kwargs) + output = torch.nn.functional.dropout(output, p=self.dropout) + return output + + +def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: + r""" + Apply layer skipping to internal layers of a transformer. + + Args: + module (`torch.nn.Module`): + The transformer model to which the layer skip hook should be applied. + config (`LayerSkipConfig`): + The configuration for the layer skip hook. + + Example: + + ```python + >>> import torch + >>> from diffusers import CogVideoXTransformer3DModel, LayerSkipConfig, apply_layer_skip + >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16) + >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks") + >>> apply_layer_skip(transformer, config) + ``` + """ + _apply_layer_skip_hook(module, config) + + +def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None: + name = name or _LAYER_SKIP_HOOK + + if config.skip_attention and config.skip_attention_scores: + raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True.
Please choose one.") + if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores: + raise ValueError("Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0.") + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList): + raise ValueError( + f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify " + f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks." + ) + if len(config.indices) == 0: + raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.") + + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + if config.skip_attention and config.skip_ff: + logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = TransformerBlockSkipHook(config.dropout) + registry.register_hook(hook, name) + + elif config.skip_attention or config.skip_attention_scores: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: + logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'") + output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout) + registry.register_hook(hook, name) + + if config.skip_ff: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _FEEDFORWARD_CLASSES): + logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = FeedForwardSkipHook(config.dropout) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py new file mode 100644 index 000000000000..f0366e29887f --- /dev/null +++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py @@ -0,0 +1,158 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from ..utils import get_logger
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
+
+
+@dataclass
+class SmoothedEnergyGuidanceConfig:
+    r"""
+    Configuration for applying smoothed energy guidance to the self-attention layers of a stack of
+    transformer blocks.
+
+    Args:
+        indices (`List[int]`):
+            The indices of the transformer blocks whose self-attention query projections the Gaussian
+            blur is applied to.
+        fqn (`str`, defaults to `"auto"`):
+            The fully qualified name identifying the stack of transformer blocks. Typically, this is
+            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
+            For automatic detection, set this to `"auto"`.
+            "auto" only works on DiT models. For UNet models, you must provide the correct fqn.
+        _query_proj_identifiers (`List[str]`, defaults to `None`):
+            The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`.
+            If `None`, `to_q` is used by default.
+    """
+
+    indices: List[int]
+    fqn: str = "auto"
+    _query_proj_identifiers: Optional[List[str]] = None
+
+
+class SmoothedEnergyGuidanceHook(ModelHook):
+    def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
+        super().__init__()
+        self.blur_sigma = blur_sigma
+        self.blur_threshold_inf = blur_threshold_inf
+
+    def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
+        # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
+        # Smallest odd kernel size covering roughly six standard deviations of the Gaussian
+        kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
+        smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
+        return smoothed_output
+
+
+def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
+    name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
+
+    if config.fqn == "auto":
+        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
+            if hasattr(module, identifier):
+                config.fqn = identifier
+                break
+        else:
+            raise ValueError(
+                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
+                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
+ ) + + if config._query_proj_identifiers is None: + config._query_proj_identifiers = ["to_q"] + + transformer_blocks = _get_submodule_from_fqn(module, config.fqn) + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + + blocks_found = True + + for submodule_name, submodule in block.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention: + continue + for identifier in config._query_proj_identifiers: + query_proj = getattr(submodule, identifier, None) + if query_proj is None or not isinstance(query_proj, torch.nn.Linear): + continue + logger.debug( + f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}" + ) + registry = HookRegistry.check_if_exists_or_initialize(query_proj) + hook = SmoothedEnergyGuidanceHook(blur_sigma) + registry.register_hook(hook, name) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) + + +# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71 +def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor: + """ + This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian + blur. However, some models use joint text-visual token attention for which this may not be suitable. Additionally, + this implementation also assumes that the visual tokens come from a square image/video. In practice, despite + these assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results + for Smoothed Energy Guidance. + + SEG is only supported as an experimental prototype feature for now, so the implementation may be modified + in the future without warning or guarantee of reproducibility. 
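+
+    As an illustrative worked example of the kernel sizing used by the caller
+    (`SmoothedEnergyGuidanceHook.post_forward` above): with `blur_sigma=1.0`, the kernel size is
+    `math.ceil(6 * 1.0) + 1 - math.ceil(6 * 1.0) % 2 = 6 + 1 - 0 = 7`, i.e. a 7x7 Gaussian kernel
+    is convolved per-channel over the (assumed square) grid of visual tokens.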
+ """ + assert query.ndim == 3 + + is_inf = sigma > sigma_threshold_inf + batch_size, seq_len, embed_dim = query.shape + + seq_len_sqrt = int(math.sqrt(seq_len)) + num_square_tokens = seq_len_sqrt * seq_len_sqrt + query_slice = query[:, :num_square_tokens, :] + query_slice = query_slice.permute(0, 2, 1) + query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt) + + if is_inf: + kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1)) + kernel_size_half = (kernel_size - 1) / 2 + + x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + kernel1d = kernel1d.to(query) + kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :]) + kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + query_slice = F.pad(query_slice, padding, mode="reflect") + query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim) + else: + query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True) + + query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens) + query_slice = query_slice.permute(0, 2, 1) + query[:, :num_square_tokens, :] = query_slice.clone() + + return query diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py new file mode 100644 index 000000000000..cb2ed78ce360 --- /dev/null +++ b/src/diffusers/modular_pipelines/__init__.py @@ -0,0 +1,82 @@ +from typing import TYPE_CHECKING + +from ..utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +# These modules contain pipelines from multiple libraries/frameworks +_dummy_objects = {} +_import_structure = {} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["modular_pipeline"] = [ + "ModularPipelineMixin", + "PipelineBlock", + "AutoPipelineBlocks", + "SequentialPipelineBlocks", + "LoopSequentialPipelineBlocks", + "ModularLoader", + "PipelineState", + "BlockState", + ] + _import_structure["modular_pipeline_utils"] = [ + "ComponentSpec", + "ConfigSpec", + "InputParam", + "OutputParam", + ] + _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"] + _import_structure["components_manager"] = ["ComponentsManager"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 + else: + from .modular_pipeline import ( + AutoPipelineBlocks, + BlockState, + LoopSequentialPipelineBlocks, + ModularLoader, + ModularPipelineMixin, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, + ) + from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + ) + from .stable_diffusion_xl import ( + StableDiffusionXLAutoPipeline, + StableDiffusionXLModularLoader, + ) + from .components_manager import ComponentsManager +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value 
+        setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py
similarity index 51%
rename from src/diffusers/pipelines/components_manager.py
rename to src/diffusers/modular_pipelines/components_manager.py
index 6d7665e29292..992353389b95
--- a/src/diffusers/pipelines/components_manager.py
+++ b/src/diffusers/modular_pipelines/components_manager.py
@@ -26,6 +26,10 @@
     logging,
 )
 from ..models.modeling_utils import ModelMixin
+from .modular_pipeline_utils import ComponentSpec
+
+import uuid
 
 
 if is_accelerate_available():
@@ -229,54 +233,209 @@ def search_best_candidate(module_sizes, min_memory_offload):
 
     return hooks_to_offload
 
+
 class ComponentsManager:
     def __init__(self):
         self.components = OrderedDict()
-        self.added_time = OrderedDict()  # Store when components were added
+        self.added_time = OrderedDict()  # Store when components were added
+        self.collections = OrderedDict()  # collection_name -> set of component_names
         self.model_hooks = None
         self._auto_offload_enabled = False
 
-    def add(self, name, component):
-        if name in self.components:
-            logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
-        self.components[name] = component
-        self.added_time[name] = time.time()
+    def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None):
+        """
+        Look up component ids by name, collection, or load_id.
+        """
+        if components is None:
+            components = self.components
+
+        if name:
+            ids_by_name = set()
+            for component_id, component in components.items():
+                comp_name = self._id_to_name(component_id)
+                if comp_name == name:
+                    ids_by_name.add(component_id)
+        else:
+            ids_by_name = set(components.keys())
+        if collection:
+            ids_by_collection = set()
+            for component_id, component in components.items():
+                if component_id in self.collections.get(collection, set()):
+                    ids_by_collection.add(component_id)
+        else:
+            ids_by_collection = set(components.keys())
+        if load_id:
+            ids_by_load_id = set()
+            for component_id, component in components.items():
+                if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
+                    ids_by_load_id.add(component_id)
+        else:
+            ids_by_load_id = set(components.keys())
+        ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id)
+        return ids
+
+    @staticmethod
+    def _id_to_name(component_id: str):
+        return "_".join(component_id.split("_")[:-1])
+
+    def add(self, name, component, collection: Optional[str] = None):
+        component_id = f"{name}_{uuid.uuid4()}"
+
+        # check for duplicated components
+        for comp_id, comp in self.components.items():
+            if comp == component:
+                comp_name = self._id_to_name(comp_id)
+                if comp_name == name:
+                    logger.warning(
+                        f"Component '{name}' already exists as '{comp_id}'"
+                    )
+                    component_id = comp_id
+                    break
+                else:
+                    logger.warning(
+                        f"Adding component '{name}' as '{component_id}', but it is a duplicate of '{comp_id}'. "
+                        f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
+                    )
+
+        # check for duplicated load_id and warn (we do not delete for you)
+        if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
+            components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
+            components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]
+
+            if components_with_same_load_id:
+                existing = ", ".join(components_with_same_load_id)
+                logger.warning(
+                    f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
+                    f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
+                )
+
+        # add component to components manager
+        self.components[component_id] = component
+        self.added_time[component_id] = time.time()
+
+        if collection:
+            if collection not in self.collections:
+                self.collections[collection] = set()
+            if component_id not in self.collections[collection]:
+                comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
+                for comp_id in comp_ids_in_collection:
+                    logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}")
+                    self.remove(comp_id)
+                self.collections[collection].add(component_id)
+                logger.info(f"Added component '{name}' in collection '{collection}': {component_id}")
+        else:
+            logger.info(f"Added component '{name}' as '{component_id}'")
 
         if self._auto_offload_enabled:
-            self.enable_auto_cpu_offload(self._auto_offload_device)
+            self.enable_auto_cpu_offload(self._auto_offload_device)
+
+        return component_id
 
-    def remove(self, name):
-        if name not in self.components:
-            logger.warning(f"Component '{name}' not found in ComponentsManager")
+    def remove(self, component_id: Optional[str] = None):
+        if component_id not in self.components:
+            logger.warning(f"Component '{component_id}' not found in ComponentsManager")
             return
-
-        self.components.pop(name)
-        self.added_time.pop(name)
+
+        component = self.components.pop(component_id)
+        self.added_time.pop(component_id)
+
+        for collection in self.collections:
+            if component_id in self.collections[collection]:
+                self.collections[collection].remove(component_id)
 
         if self._auto_offload_enabled:
             self.enable_auto_cpu_offload(self._auto_offload_device)
-
-    # YiYi TODO: looking into improving the search pattern
-    def get(self, names: Union[str, List[str]]):
+        else:
+            if isinstance(component, torch.nn.Module):
+                component.to("cpu")
+            del component
+            import gc
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+    def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None,
+            as_name_component_tuples: bool = False):
         """
-        Get components by name with simple pattern matching.
+        Select components by name with simple pattern matching.
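+
+        For example (illustrative, assuming a component was added under the base name "unet"):
+        `get("unet")` returns a dict keyed by the full component id, e.g. `{"unet_<uuid>": unet}`,
+        and `get("!unet")` returns every component whose base name is not "unet".
+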
Args: names: Component name(s) or pattern(s) Patterns: - - "unet" : exact match - - "!unet" : everything except exact match "unet" - - "base_*" : everything starting with "base_" - - "!base_*" : everything NOT starting with "base_" - - "*unet*" : anything containing "unet" - - "!*unet*" : anything NOT containing "unet" - - "refiner|vae|unet" : anything containing any of these terms - - "!refiner|vae|unet" : anything NOT containing any of these terms + - "unet" : match any component with base name "unet" (e.g., unet_123abc) + - "!unet" : everything except components with base name "unet" + - "unet*" : anything with base name starting with "unet" + - "!unet*" : anything with base name NOT starting with "unet" + - "*unet*" : anything with base name containing "unet" + - "!*unet*" : anything with base name NOT containing "unet" + - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" + - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" + - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" + collection: Optional collection to filter by + load_id: Optional load_id to filter by + as_name_component_tuples: If True, returns a list of (name, component) tuples using base names + instead of a dictionary with component IDs as keys Returns: - Single component if names is str and matches one component, - dict of components if names matches multiple components or is a list + Dictionary mapping component IDs to components, + or list of (base_name, component) tuples if as_name_component_tuples=True """ + + selected_ids = self._lookup_ids(collection=collection, load_id=load_id) + components = {k: self.components[k] for k in selected_ids} + + # Helper to extract base name from component_id + def get_base_name(component_id): + parts = component_id.split('_') + # If the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return component_id + + if names is None: + if as_name_component_tuples: + return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] + else: + return components + + # Create mapping from component_id to base_name for all components + base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} + + def matches_pattern(component_id, pattern, exact_match=False): + """ + Helper function to check if a component matches a pattern based on its base name. 
+ + Args: + component_id: The component ID to check + pattern: The pattern to match against + exact_match: If True, only exact matches to base_name are considered + """ + base_name = base_names[component_id] + + # Exact match with base name + if exact_match: + return pattern == base_name + + # Prefix match (ends with *) + elif pattern.endswith('*'): + prefix = pattern[:-1] + return base_name.startswith(prefix) + + # Contains match (starts with *) + elif pattern.startswith('*'): + search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] + return search in base_name + + # Exact match (no wildcards) + else: + return pattern == base_name + if isinstance(names, str): # Check if this is a "not" pattern is_not_pattern = names.startswith('!') @@ -286,33 +445,45 @@ def get(self, names: Union[str, List[str]]): # Handle OR patterns (containing |) if '|' in names: terms = names.split('|') + matches = {} + + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) + + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) + + # Flip the decision if this is a NOT pattern + if is_not_pattern: + should_include = not should_include + + if should_include: + matches[comp_id] = comp + + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") + + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name matches = { - name: comp for name, comp in self.components.items() - if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern } + if is_not_pattern: - logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}") - else: - logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}") - - # Exact match - elif names in self.components: - if is_not_pattern: - matches = { - name: comp for name, comp in self.components.items() - if name != names - } - logger.info(f"Getting all components except '{names}': {list(matches.keys())}") + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: - logger.info(f"Getting component: {names}") - return self.components[names] + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") # Prefix match (ends with *) elif names.endswith('*'): prefix = names[:-1] matches = { - name: comp for name, comp in self.components.items() - if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") @@ -323,31 +494,46 @@ def get(self, names: Union[str, List[str]]): elif names.startswith('*'): search = names[1:-1] if names.endswith('*') else names[1:] matches = { - name: comp for name, comp in self.components.items() - if (search in name) != is_not_pattern # Flip condition if not pattern + comp_id: comp 
for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") + else: - raise ValueError(f"Component '{names}' not found in ComponentsManager") + raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") if not matches: raise ValueError(f"No components found matching pattern '{names}'") - return matches if len(matches) > 1 else next(iter(matches.values())) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] + else: + return matches elif isinstance(names, list): results = {} for name in names: - result = self.get(name) - if isinstance(result, dict): - results.update(result) - else: - results[name] = result - logger.info(f"Getting multiple components: {list(results.keys())}") - return results + result = self.get(name, collection, load_id, as_name_component_tuples=False) + results.update(result) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in results.items()] + else: + return results else: raise ValueError(f"Invalid type for names: {type(names)}") @@ -391,11 +577,12 @@ def disable_auto_cpu_offload(self): self.model_hooks = None self._auto_offload_enabled = False - def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: + # YiYi TODO: add quantization info + def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. Args: - name: Name of the component to get info for + component_id: Name of the component to get info for fields: Optional field(s) to return. Can be a string for single field or list of fields. If None, returns all fields. @@ -404,23 +591,32 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No If fields is specified, returns only those fields. If a single field is requested as string, returns just that field's value. 
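+
+        Example (an illustrative sketch; assumes `manager` is a `ComponentsManager` and `unet` is a
+        loaded `torch.nn.Module`):
+
+        ```python
+        >>> comp_id = manager.add("unet", unet)
+        >>> manager.get_model_info(comp_id, fields="size_gb")  # returns a single float
+        ```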
""" - if name not in self.components: - raise ValueError(f"Component '{name}' not found in ComponentsManager") + if component_id not in self.components: + raise ValueError(f"Component '{component_id}' not found in ComponentsManager") - component = self.components[name] + component = self.components[component_id] # Build complete info dict first info = { - "model_id": name, - "added_time": self.added_time[name], + "model_id": component_id, + "added_time": self.added_time[component_id], + "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None, } # Additional info for torch.nn.Module components if isinstance(component, torch.nn.Module): + # Check for hook information + has_hook = hasattr(component, "_hf_hook") + execution_device = None + if has_hook and hasattr(component._hf_hook, "execution_device"): + execution_device = component._hf_hook.execution_device + info.update({ "class_name": component.__class__.__name__, "size_gb": get_memory_footprint(component) / (1024**3), "adapters": None, # Default to None + "has_hook": has_hook, + "execution_device": execution_device, }) # Get adapters if applicable @@ -454,12 +650,64 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No return info def __repr__(self): + # Helper to get simple name without UUID + def get_simple_name(name): + # Extract the base name by splitting on underscore and taking first part + # This assumes names are in format "name_uuid" + parts = name.split('_') + # If we have at least 2 parts and the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return name + + # Extract load_id if available + def get_load_id(component): + if hasattr(component, "_diffusers_load_id"): + return component._diffusers_load_id + return "N/A" + + # Format device info compactly + def format_device(component, info): + if not info["has_hook"]: + return str(getattr(component, 'device', 'N/A')) + else: + device = str(getattr(component, 'device', 'N/A')) + exec_device = str(info['execution_device'] or 'N/A') + return f"{device}({exec_device})" + + # Get all simple names to calculate width + simple_names = [get_simple_name(id) for id in self.components.keys()] + + # Get max length of load_ids for models + load_ids = [ + get_load_id(component) + for component in self.components.values() + if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") + ] + max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 + + # Get all collections for each component + component_collections = {} + for name in self.components.keys(): + component_collections[name] = [] + for coll, comps in self.collections.items(): + if name in comps: + component_collections[name].append(coll) + if not component_collections[name]: + component_collections[name] = ["N/A"] + + # Find the maximum collection name length + all_collections = [coll for colls in component_collections.values() for coll in colls] + max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10 + col_widths = { - "id": max(15, max(len(id) for id in self.components.keys())), + "name": max(15, max(len(name) for name in simple_names)), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), - "device": 10, + "device": 15, # Reduced since using more compact format "dtype": 15, "size": 10, + "load_id": max_load_id_len, + 
"collection": max_collection_len } # Create the header lines @@ -476,17 +724,33 @@ def __repr__(self): if models: output += "Models:\n" + dash_line # Column headers - output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " + output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" output += dash_line # Model entries for name, component in models.items(): info = self.get_model_info(name) - device = str(getattr(component, "device", "N/A")) + simple_name = get_simple_name(name) + device_str = format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" - output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | " - output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n" + load_id = get_load_id(component) + + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " + output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | " + output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | " + output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n" + output += dash_line # Other components section @@ -495,12 +759,24 @@ def __repr__(self): output += "\n" output += "Other Components:\n" + dash_line # Column headers for other components - output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" output += dash_line # Other component entries for name, component in others.items(): - output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n" + info = self.get_model_info(name) + simple_name = get_simple_name(name) + + # Print first collection on the main line + first_collection = component_collections[name][0] if component_collections[name] else "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n" + + # Print additional collections on separate lines if they exist + for i in range(1, len(component_collections[name])): + collection = component_collections[name][i] + output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n" + output += dash_line # Add additional component info @@ -508,7 +784,8 @@ def __repr__(self): for name in self.components: info = self.get_model_info(name) if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): - output += f"\n{name}:\n" + simple_name = get_simple_name(name) + output += f"\n{simple_name}:\n" if 
info.get("adapters") is not None:
                 output += f"  Adapters: {info['adapters']}\n"
             if info.get("ip_adapter"):
@@ -517,7 +794,7 @@
 
         return output
 
-    def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
+    def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
         """
         Load components from a pretrained model and add them to the manager.
 
@@ -527,17 +804,12 @@ def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[st
             If provided, components will be named as "{prefix}_{component_name}"
             **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
         """
-        from ..pipelines.pipeline_utils import DiffusionPipeline
-
-        pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        for name, component in pipe.components.items():
-
-            if component is None:
-                continue
-
-            # Add prefix if specified
-            component_name = f"{prefix}_{name}" if prefix else name
-
+        subfolder = kwargs.pop("subfolder", None)
+        # YiYi TODO: extend AutoModel to support non-diffusers models
+        if subfolder:
+            from ..models import AutoModel
+            component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs)
+            component_name = f"{prefix}_{subfolder}" if prefix else subfolder
             if component_name not in self.components:
                 self.add(component_name, component)
             else:
@@ -546,6 +818,59 @@
                     f"1. remove the existing component with remove('{component_name}')\n"
-                    f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
+                    f"2. Use a different prefix: from_pretrained(..., prefix='{prefix}_2')"
                 )
+        else:
+            from ..pipelines.pipeline_utils import DiffusionPipeline
+
+            pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            for name, component in pipe.components.items():
+                if component is None:
+                    continue
+
+                # Add prefix if specified
+                component_name = f"{prefix}_{name}" if prefix else name
+
+                if component_name not in self.components:
+                    self.add(component_name, component)
+                else:
+                    logger.warning(
+                        f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
+                        f"1. remove the existing component with remove('{component_name}')\n"
+                        f"2. Use a different prefix: from_pretrained(..., prefix='{prefix}_2')"
+                    )
+
+    def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any:
+        """
+        Get a single component. Either look it up directly by `component_id`, or filter by name pattern,
+        collection, and/or load_id. Raises an error if multiple components match or none are found.
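+
+        For example (illustrative): `get_one(name="unet", collection="sdxl")` returns the single
+        matching component, while `get_one(name="unet*")` raises a ValueError if more than one
+        component's base name starts with "unet".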
+
+        Args:
+            component_id: Optional component id (base name plus UUID suffix) to look up directly
+            name: Component name or pattern
+            collection: Optional collection to filter by
+            load_id: Optional load_id to filter by
+
+        Returns:
+            A single component
+
+        Raises:
+            ValueError: If no components match or multiple components match
+        """
+
+        # if component_id is provided, look it up directly
+        if component_id is not None and (name is not None or collection is not None or load_id is not None):
+            raise ValueError("If `component_id` is provided, `name`, `collection`, and `load_id` must be None.")
+        elif component_id is not None:
+            if component_id not in self.components:
+                raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
+            return self.components[component_id]
+
+        results = self.get(name, collection, load_id)
+
+        if not results:
+            raise ValueError(f"No components found matching '{name}'")
+
+        if len(results) > 1:
+            raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
+
+        return next(iter(results.values()))
 
 
 def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
     """Summarizes a dictionary by finding common prefixes that share the same value.
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
new file mode 100644
index 000000000000..3136c3bb11f1
--- /dev/null
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -0,0 +1,2200 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+import re
+import traceback
+import warnings
+from collections import OrderedDict
+from copy import deepcopy
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from tqdm.auto import tqdm
+
+from huggingface_hub.utils import validate_hf_hub_args
+
+from ..configuration_utils import ConfigMixin, FrozenDict
+from ..utils import (
+    PushToHubMixin,
+    is_accelerate_available,
+    logging,
+)
+from ..pipelines.pipeline_loading_utils import simple_get_class_obj, _fetch_class_library_tuple
+from .modular_pipeline_utils import (
+    ComponentSpec,
+    ConfigSpec,
+    InputParam,
+    OutputParam,
+    format_components,
+    format_configs,
+    format_inputs_short,
+    format_intermediates_short,
+    make_doc_string,
+)
+from .components_manager import ComponentsManager
+from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+
+
+if is_accelerate_available():
+    import accelerate
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+MODULAR_LOADER_MAPPING = OrderedDict(
+    [
+        ("stable-diffusion-xl", "StableDiffusionXLModularLoader"),
+    ]
+)
+
+
+@dataclass
+class PipelineState:
+    """
+    [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks.
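+
+    A minimal illustrative sketch (`latents` is a hypothetical tensor):
+
+    ```python
+    >>> state = PipelineState()
+    >>> state.add_input("prompt", "a photo of a cat")
+    >>> state.add_intermediate("latents", latents)
+    >>> state.get_intermediate("latents")
+    ```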
+    """
+
+    inputs: Dict[str, Any] = field(default_factory=dict)
+    intermediates: Dict[str, Any] = field(default_factory=dict)
+    input_kwargs: Dict[str, List[str]] = field(default_factory=dict)
+    intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict)
+
+    def add_input(self, key: str, value: Any, kwargs_type: Optional[str] = None):
+        """
+        Add an input to the pipeline state with optional metadata.
+
+        Args:
+            key (str): The key for the input
+            value (Any): The input value
+            kwargs_type (str): The kwargs_type to store with the input
+        """
+        self.inputs[key] = value
+        if kwargs_type is not None:
+            if kwargs_type not in self.input_kwargs:
+                self.input_kwargs[kwargs_type] = [key]
+            else:
+                self.input_kwargs[kwargs_type].append(key)
+
+    def add_intermediate(self, key: str, value: Any, kwargs_type: Optional[str] = None):
+        """
+        Add an intermediate value to the pipeline state with optional metadata.
+
+        Args:
+            key (str): The key for the intermediate value
+            value (Any): The intermediate value
+            kwargs_type (str): The kwargs_type to store with the intermediate value
+        """
+        self.intermediates[key] = value
+        if kwargs_type is not None:
+            if kwargs_type not in self.intermediate_kwargs:
+                self.intermediate_kwargs[kwargs_type] = [key]
+            else:
+                self.intermediate_kwargs[kwargs_type].append(key)
+
+    def get_input(self, key: str, default: Any = None) -> Any:
+        value = self.inputs.get(key, default)
+        if value is not None:
+            # return a copy so downstream blocks do not mutate the original user input
+            return deepcopy(value)
+
+    def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
+        return {key: self.inputs.get(key, default) for key in keys}
+
+    def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
+        """
+        Get all inputs with matching kwargs_type.
+
+        Args:
+            kwargs_type (str): The kwargs_type to filter by
+
+        Returns:
+            Dict[str, Any]: Dictionary of inputs with matching kwargs_type
+        """
+        input_names = self.input_kwargs.get(kwargs_type, [])
+        return self.get_inputs(input_names)
+
+    def get_intermediates_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
+        """
+        Get all intermediates with matching kwargs_type.
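+
+        For example (illustrative): if "prompt_embeds" was added via
+        `add_intermediate("prompt_embeds", embeds, kwargs_type="guider_input_fields")`, then
+        `get_intermediates_kwargs("guider_input_fields")` returns `{"prompt_embeds": embeds}`.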
+ + Args: + kwargs_type (str): The kwargs_type to filter by + + Returns: + Dict[str, Any]: Dictionary of intermediates with matching kwargs_type + """ + intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) + return self.get_intermediates(intermediate_names) + + def get_intermediate(self, key: str, default: Any = None) -> Any: + return self.intermediates.get(key, default) + + def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + return {key: self.intermediates.get(key, default) for key in keys} + + def to_dict(self) -> Dict[str, Any]: + return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} + + def __repr__(self): + def format_value(v): + if hasattr(v, "shape") and hasattr(v, "dtype"): + return f"Tensor(dtype={v.dtype}, shape={v.shape})" + elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): + return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" + else: + return repr(v) + + inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) + intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + + # Format input_kwargs and intermediate_kwargs + input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) + intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) + + return ( + f"PipelineState(\n" + f" inputs={{\n{inputs}\n }},\n" + f" intermediates={{\n{intermediates}\n }},\n" + f" input_kwargs={{\n{input_kwargs_str}\n }},\n" + f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" + f")" + ) + + +@dataclass +class BlockState: + """ + Container for block state data with attribute access and formatted representation. + """ + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __getitem__(self, key: str): + # allows block_state["foo"] + return getattr(self, key, None) + + def __setitem__(self, key: str, value: Any): + # allows block_state["foo"] = "bar" + setattr(self, key, value) + + def as_dict(self): + """ + Convert BlockState to a dictionary. 
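+
+        For example (illustrative): `BlockState(latents=latents).as_dict()` returns `{"latents": latents}`.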
+
+        Returns:
+            Dict[str, Any]: Dictionary containing all attributes of the BlockState
+        """
+        return {key: value for key, value in self.__dict__.items()}
+
+    def __repr__(self):
+        def format_value(v):
+            # Handle tensors directly
+            if hasattr(v, "shape") and hasattr(v, "dtype"):
+                return f"Tensor(dtype={v.dtype}, shape={v.shape})"
+
+            # Handle lists of tensors
+            elif isinstance(v, list):
+                if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"):
+                    shapes = [t.shape for t in v]
+                    return f"List[{len(v)}] of Tensors with shapes {shapes}"
+                return repr(v)
+
+            # Handle tuples of tensors
+            elif isinstance(v, tuple):
+                if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"):
+                    shapes = [t.shape for t in v]
+                    return f"Tuple[{len(v)}] of Tensors with shapes {shapes}"
+                return repr(v)
+
+            # Handle dicts with tensor values
+            elif isinstance(v, dict):
+                formatted_dict = {}
+                for k, val in v.items():
+                    if hasattr(val, "shape") and hasattr(val, "dtype"):
+                        formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})"
+                    elif isinstance(val, list) and len(val) > 0 and hasattr(val[0], "shape") and hasattr(val[0], "dtype"):
+                        shapes = [t.shape for t in val]
+                        formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}"
+                    else:
+                        formatted_dict[k] = repr(val)
+                return formatted_dict
+
+            # Default case
+            return repr(v)
+
+        attributes = "\n".join(f"    {k}: {format_value(v)}" for k, v in self.__dict__.items())
+        return f"BlockState(\n{attributes}\n)"
+
+
+class ModularPipelineMixin(ConfigMixin):
+    """
+    Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
+    """
+
+    config_name = "config.json"
+
+    @classmethod
+    def _get_signature_keys(cls, obj):
+        parameters = inspect.signature(obj.__init__).parameters
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        optional_parameters = {k for k, v in parameters.items() if v.default != inspect._empty}
+        expected_modules = set(required_parameters.keys()) - {"self"}
+
+        return expected_modules, optional_parameters
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str,
+        trust_remote_code: Optional[bool] = None,
+        **kwargs,
+    ):
+        hub_kwargs_names = [
+            "cache_dir",
+            "force_download",
+            "local_files_only",
+            "proxies",
+            "resume_download",
+            "revision",
+            "subfolder",
+            "token",
+        ]
+        hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
+
+        config = cls.load_config(pretrained_model_name_or_path)
+        has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
+        trust_remote_code = resolve_trust_remote_code(
+            trust_remote_code, pretrained_model_name_or_path, has_remote_code
+        )
+        if not (has_remote_code and trust_remote_code):
+            raise ValueError(
+                f"Loading `{cls.__name__}` from '{pretrained_model_name_or_path}' requires custom code: the repository "
+                f"config must contain an `auto_map` entry for `{cls.__name__}`, and `trust_remote_code=True` must be "
+                "passed to allow it."
+            )
+
+        class_ref = config["auto_map"][cls.__name__]
+        module_file, class_name = class_ref.split(".")
+        module_file = module_file + ".py"
+        block_cls = get_class_from_dynamic_module(
+            pretrained_model_name_or_path,
+            module_file=module_file,
+            class_name=class_name,
+            is_modular=True,
+            **hub_kwargs,
+            **kwargs,
+        )
+        expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
+        block_kwargs = {
+            name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
+        }
+
+        return block_cls(**block_kwargs)
+
+    def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
+        """
+        Create a ModularLoader; optionally accepts a `modular_repo` to load component specs from the Hub.
+        """
+
+        # Import the components loader (it is a model-specific class)
+        loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__)
+        diffusers_module = importlib.import_module("diffusers")
+        loader_class = getattr(diffusers_module, loader_class_name)
+
+        # Create deep copies to avoid modifying the original specs
+        component_specs = deepcopy(self.expected_components)
+        config_specs = deepcopy(self.expected_configs)
+        # Create the loader with the combined specs
+        specs = component_specs + config_specs
+
+        self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection)
+
+    @property
+    def default_call_parameters(self) -> Dict[str, Any]:
+        params = {}
+        for input_param in self.inputs:
+            params[input_param.name] = input_param.default
+        return params
+
+    def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
+        """
+        Run one or more blocks in sequence; optionally pass a previous pipeline state to resume from.
+        """
+        if state is None:
+            state = PipelineState()
+
+        if not hasattr(self, "loader"):
+            logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
+            self.loader = None
+
+        # Make a copy of the input kwargs
+        passed_kwargs = kwargs.copy()
+
+        # Add inputs to the state, using defaults if not provided in the kwargs or the state.
+        # If the same input is already in the state, it is overridden when provided in the kwargs.
+        intermediates_inputs = [inp.name for inp in self.intermediates_inputs]
+        for expected_input_param in self.inputs:
+            name = expected_input_param.name
+            default = expected_input_param.default
+            kwargs_type = expected_input_param.kwargs_type
+            if name in passed_kwargs:
+                if name not in intermediates_inputs:
+                    state.add_input(name, passed_kwargs.pop(name), kwargs_type)
+                else:
+                    state.add_input(name, passed_kwargs[name], kwargs_type)
+            elif name not in state.inputs:
+                state.add_input(name, default, kwargs_type)
+
+        for expected_intermediate_param in self.intermediates_inputs:
+            name = expected_intermediate_param.name
+            kwargs_type = expected_intermediate_param.kwargs_type
+            if name in passed_kwargs:
+                state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type)
+
+        # Warn about unexpected inputs
+        if len(passed_kwargs) > 0:
+            warnings.warn(f"Unexpected input(s) {list(passed_kwargs.keys())} provided. These inputs will be ignored.")
+
+        # Run the pipeline
+        with torch.no_grad():
+            try:
+                pipeline, state = self(self.loader, state)
+            except Exception:
+                error_msg = f"Error in block ({self.__class__.__name__}):\n{traceback.format_exc()}"
+                logger.error(error_msg)
+                raise
+
+        if output is None:
+            return state
+        elif isinstance(output, str):
+            return state.get_intermediate(output)
+        elif isinstance(output, (list, tuple)):
+            return state.get_intermediates(output)
+        else:
+            raise ValueError(f"Output '{output}' is not a valid output type")
+
+    @torch.compiler.disable
+    def progress_bar(self, iterable=None, total=None):
+        if not hasattr(self, "_progress_bar_config"):
+            self._progress_bar_config = {}
+        elif not isinstance(self._progress_bar_config, dict):
+            raise ValueError(
+                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ ) + + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + +class PipelineBlock(ModularPipelineMixin): + + model_name = None + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + # raise NotImplementedError("description method must be implemented in subclasses") + return "TODO: add a description" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [] + + + @property + def inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] + + def _get_outputs(self): + return self.intermediates_outputs + + # YiYi TODO: is it too easy for user to unintentionally override these properties? + # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks + @property + def outputs(self) -> List[OutputParam]: + return self._get_outputs() + + def _get_required_inputs(self): + input_names = [] + for input_param in self.inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def required_inputs(self) -> List[str]: + return self._get_required_inputs() + + + def _get_required_intermediates_inputs(self): + input_names = [] + for input_param in self.intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + return self._get_required_intermediates_inputs() + + + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + raise NotImplementedError("__call__ method must be implemented in subclasses") + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - use format_components with add_empty_lines=False + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + components = " " + components_str.replace("\n", "\n ") + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + configs = " " + configs_str.replace("\n", "\n 
") + + # Inputs section + inputs_str = format_inputs_short(self.inputs) + inputs = "Inputs:\n " + inputs_str + + # Intermediates section + intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates = f"Intermediates:\n{intermediates_str}" + + return ( + f"{class_name}(\n" + f" Class: {base_class}\n" + f"{desc}" + f"{components}\n" + f"{configs}\n" + f" {inputs}\n" + f" {intermediates}\n" + f")" + ) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + + + # YiYi TODO: input and inteermediate inputs with same name? should warn? + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def add_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # 
+ +def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: + """ + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if the + current default value is None and the new default value is not None. Warns if multiple non-None default values + exist for the same input. + + Args: + named_input_lists: List of tuples containing (block_name, input_param_list) pairs + + Returns: + List[InputParam]: Combined list of unique InputParam objects + """ + combined_dict = {} # name -> InputParam + value_sources = {} # name -> block_name + + for block_name, inputs in named_input_lists: + for input_param in inputs: + if input_param.name is None and input_param.kwargs_type is not None: + input_name = "*_" + input_param.kwargs_type + else: + input_name = input_param.name + if input_name in combined_dict: + current_param = combined_dict[input_name] + if (current_param.default is not None and + input_param.default is not None and + current_param.default != input_param.default): + warnings.warn( + f"Multiple different default values found for input '{input_name}': " + f"{current_param.default} (from block '{value_sources[input_name]}') and " + f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." + ) + if current_param.default is None and input_param.default is not None: + combined_dict[input_name] = input_param + value_sources[input_name] = block_name + else: + combined_dict[input_name] = input_param + value_sources[input_name] = block_name + + return list(combined_dict.values()) + +def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: + """ + Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, + keeps the first occurrence of each output name. + + Args: + named_output_lists: List of tuples containing (block_name, output_param_list) pairs + + Returns: + List[OutputParam]: Combined list of unique OutputParam objects + """ + combined_dict = {} # name -> OutputParam + + for block_name, outputs in named_output_lists: + for output_param in outputs: + if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None): + combined_dict[output_param.name] = output_param + + return list(combined_dict.values())
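+ + +# Example of the merge behavior (illustrative values): two blocks declaring the same input with +# different non-None defaults keep the first default and emit a warning: +# +# combined = combine_inputs( +# ("decode", [InputParam(name="output_type", default="pil")]), +# ("postprocess", [InputParam(name="output_type", default="np")]), +# ) +# # -> a single InputParam named "output_type" with default "pil", plus a warning about "np"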
+ + +class AutoPipelineBlocks(ModularPipelineMixin): + """ + A class that automatically selects a block to run based on the inputs. + + Attributes: + block_classes: List of block classes to be used + block_names: List of names for each block + block_trigger_inputs: List of input names that trigger specific blocks, with None for the default + """ + + block_classes = [] + block_names = [] + block_trigger_inputs = [] + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): + raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") + default_blocks = [t for t in self.block_trigger_inputs if t is None] + # there can be at most one default block (trigger None), and it has to be placed last; + # the order of blocks matters here because the first block with a matching trigger will be dispatched + # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] + # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img + if len(default_blocks) > 1 or ( + len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None + ): + raise ValueError( + f"In {self.__class__.__name__}, at most one None may be specified in block_trigger_inputs, " + "and it must be the last element." + ) + + # Map trigger inputs to block objects + self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) + self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) + self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + + @property + def description(self): + return "" + + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + + @property + def required_inputs(self) -> List[str]: + if None not in self.block_trigger_inputs: + return [] + first_block = next(iter(self.blocks.values())) + required_by_all = set(getattr(first_block, "required_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all) + + # YiYi TODO: maybe we do not need this, it is only used in the docstring; + # intermediate_inputs is required by default, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + if None not in self.block_trigger_inputs: + return [] + first_block = next(iter(self.blocks.values())) + required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) + + # Intersect with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_intermediates_inputs", set())) + required_by_all.intersection_update(block_required) + + return list(required_by_all)
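+ + # How the dispatch order plays out in practice (illustrative subclass; the three block classes + # named here are hypothetical): + # + # class AutoImageBlocks(AutoPipelineBlocks): + # block_classes = [InpaintBlock, Img2ImgBlock, Text2ImgBlock] + # block_names = ["inpaint", "img2img", "text2img"] + # # checked in order: "mask" wins over "image"; None marks the default (text2img) block + # block_trigger_inputs = ["mask", "image", None]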
+ + + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark inputs as required only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + + @property + def intermediates_inputs(self) -> List[str]: + named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark inputs as required only if that input is required by all the blocks + for input_param in combined_inputs: + if input_param.name in self.required_intermediates_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + @property + def outputs(self) -> List[str]: + named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + # Find the default block first (if any) + block = self.trigger_to_block_map.get(None) + matched_trigger = None + for input_name in self.block_trigger_inputs: + if input_name is not None and state.get_input(input_name) is not None: + block = self.trigger_to_block_map[input_name] + matched_trigger = input_name + break + elif input_name is not None and state.get_intermediate(input_name) is not None: + block = self.trigger_to_block_map[input_name] + matched_trigger = input_name + break + + if block is None: + logger.warning(f"skipping auto block: {self.__class__.__name__}") + return pipeline, state + + try: + logger.info(f"Running block: {block.__class__.__name__}, trigger: {matched_trigger}") + return block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: {block.__class__.__name__}\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. + Returns: Set[str] containing all unique block_trigger_inputs values + """ + def fn_recursive_get_trigger(blocks): + trigger_values = set() + + if blocks is not None: + for name, block in blocks.items(): + # Check if current block has trigger inputs (i.e. 
auto block) + if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + # Add all non-None values from the trigger inputs list + trigger_values.update(t for t in block.block_trigger_inputs if t is not None) + + # If block has blocks, recursively check them + if hasattr(block, 'blocks'): + nested_triggers = fn_recursive_get_trigger(block.blocks) + trigger_values.update(nested_triggers) + + return trigger_values + + trigger_inputs = set(self.block_trigger_inputs) + trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) + + return trigger_inputs + + @property + def trigger_inputs(self): + return self._get_trigger_inputs() + + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {self.trigger_inputs}\n" + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + # Get trigger input for this block + trigger = None + if hasattr(self, 'block_to_trigger_map'): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += 
f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + +class SequentialPipelineBlocks(ModularPipelineMixin): + """ + A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. + """ + block_classes = [] + block_names = [] + + + @property + def description(self): + return "" + + @property + def model_name(self): + return next(iter(self.blocks.values())).model_name + + + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": + """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. + + Args: + blocks_dict: Dictionary mapping block names to block instances + + Returns: + A new SequentialPipelineBlocks instance + """ + instance = cls() + instance.block_classes = [block.__class__ for block in blocks_dict.values()] + instance.block_names = list(blocks_dict.keys()) + instance.blocks = blocks_dict + return instance + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # YiYi TODO: maybe we do not need this, it is only used in docstring, + # intermediate_inputs is by default required, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + return self.get_inputs() + + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + combined_inputs = combine_inputs(*named_inputs) + # mark Required inputs only if that input is required any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + @property + def intermediates_inputs(self) -> List[str]: + return self.get_intermediates_inputs() + + def get_intermediates_inputs(self): + inputs = [] + 
+ + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + added_inputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + for inp in block.intermediates_inputs: + if inp.name not in outputs and inp.name not in added_inputs: + inputs.append(inp) + added_inputs.add(inp.name) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs + + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [] + for name, block in self.blocks.items(): + inp_names = set([inp.name for inp in block.intermediates_inputs]) + # a block only needs to list newly created variables as intermediates_outputs; variables it merely + # modified are already intermediates_inputs, and it is fine if a block lists them too (we don't enforce it), + # so filter them out here so they do not end up duplicated in intermediates_outputs + block_outputs = [out for out in block.intermediates_outputs if out.name not in inp_names] + named_outputs.append((name, block_outputs)) + combined_outputs = combine_outputs(*named_outputs) + return combined_outputs + + # YiYi TODO: I think we can remove the outputs property + @property + def outputs(self) -> List[str]: + # return next(reversed(self.blocks.values())).intermediates_outputs + return self.intermediates_outputs + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + for block_name, block in self.blocks.items(): + try: + pipeline, state = block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return pipeline, state
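+ + # Composing blocks sequentially (illustrative; the two block instances are hypothetical): + # + # blocks = SequentialPipelineBlocks.from_blocks_dict( + # {"text_encoder": EncodePromptBlock(), "denoise": DenoiseBlock()} + # ) + # pipeline, state = blocks(pipeline, state) # runs "text_encoder" first, then "denoise"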
+ + def _get_trigger_inputs(self): + """ + Returns a set of all unique trigger input values found in the blocks. + Returns: Set[str] containing all unique block_trigger_inputs values + """ + def fn_recursive_get_trigger(blocks): + trigger_values = set() + + if blocks is not None: + for name, block in blocks.items(): + # Check if current block has trigger inputs (i.e. auto block) + if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + # Add all non-None values from the trigger inputs list + trigger_values.update(t for t in block.block_trigger_inputs if t is not None) + + # If block has blocks, recursively check them + if hasattr(block, 'blocks'): + nested_triggers = fn_recursive_get_trigger(block.blocks) + trigger_values.update(nested_triggers) + + return trigger_values + + return fn_recursive_get_trigger(self.blocks) + + @property + def trigger_inputs(self): + return self._get_trigger_inputs() + + def _traverse_trigger_blocks(self, trigger_inputs): + # Convert trigger_inputs to a set for easier manipulation + active_triggers = set(trigger_inputs) + def fn_recursive_traverse(block, block_name, active_triggers): + result_blocks = OrderedDict() + + # sequential (including loop sequential) or PipelineBlock + if not hasattr(block, 'block_trigger_inputs'): + if hasattr(block, 'blocks'): + # sequential or LoopSequentialPipelineBlocks (keep traversing) + for sub_block_name, sub_block in block.blocks.items(): + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) + blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()} + result_blocks.update(blocks_to_update) + else: + # PipelineBlock + result_blocks[block_name] = block + # Add this block's output names to active triggers if defined + if hasattr(block, 'outputs'): + active_triggers.update(out.name for out in block.outputs) + return result_blocks + + # auto + else: + # Find the first block_trigger_input that matches any value in our active_triggers + this_block = None + matching_trigger = None + for trigger_input in block.block_trigger_inputs: + if trigger_input is not None and trigger_input in active_triggers: + this_block = block.trigger_to_block_map[trigger_input] + matching_trigger = trigger_input + break + + # If no matches found, try to get the default (None) block + if this_block is None and None in block.block_trigger_inputs: + this_block = block.trigger_to_block_map[None] + matching_trigger = None + + if this_block is not None: + # sequential/auto (keep traversing) + if hasattr(this_block, 'blocks'): + result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) + else: + # PipelineBlock + result_blocks[block_name] = this_block + # Add this block's output names to active triggers if defined + # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of the outputs attribute? 
+ if hasattr(this_block, 'outputs'): + active_triggers.update(out.name for out in this_block.outputs) + + return result_blocks + + all_blocks = OrderedDict() + for block_name, block in self.blocks.items(): + blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) + all_blocks.update(blocks_to_update) + return all_blocks + + def get_execution_blocks(self, *trigger_inputs): + trigger_inputs_all = self.trigger_inputs + + if trigger_inputs: + invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] + if invalid_inputs: + logger.warning( + f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" + ) + trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] + + if not trigger_inputs: + # fall back to the default (None) trigger if available, otherwise pick any supported trigger + if None in trigger_inputs_all: + trigger_inputs = [None] + else: + trigger_inputs = [next(iter(trigger_inputs_all))] + blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) + return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
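+ + # Illustrative use of get_execution_blocks (assumes a nested auto block is keyed on "mask"): + # + # inpaint_path = blocks.get_execution_blocks("mask") + # print(inpaint_path) # a SequentialPipelineBlocks holding only the sub-blocks that would + # # actually run when a mask input is provided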
+ + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + + if self.trigger_inputs: + header += "\n" + header += " " + "=" * 100 + "\n" + header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" + header += f" Trigger Inputs: {self.trigger_inputs}\n" + # Get first trigger input as example + example_input = next(t for t in self.trigger_inputs if t is not None) + header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + header += " " + "=" * 100 + "\n\n" + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + # Get trigger input for this block + trigger = None + if hasattr(self, 'block_to_trigger_map'): + trigger = self.block_to_trigger_map.get(name) + # Format the trigger info + if trigger is None: + trigger_str = "[default]" + elif isinstance(trigger, (list, tuple)): + trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" + else: + trigger_str = f"[trigger: {trigger}]" + # For AutoPipelineBlocks, add bullet points + blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" + else: + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + +#YiYi TODO: __repr__ +class LoopSequentialPipelineBlocks(ModularPipelineMixin): + """ + A class that combines multiple pipeline block classes into a for loop. When called, it will call each block in sequence on every iteration. + """ + + model_name = None + block_classes = [] + block_names = [] + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + raise NotImplementedError("description method must be implemented in subclasses") + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def loop_expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def loop_inputs(self) -> List[InputParam]: + """List of input parameters. Defaults to an empty list; override in subclasses as needed.""" + return []
+ + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + """List of intermediate input parameters. Defaults to an empty list; override in subclasses as needed.""" + return [] + + @property + def loop_intermediates_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Defaults to an empty list; override in subclasses as needed.""" + return [] + + + @property + def loop_required_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + @property + def loop_required_intermediates_inputs(self) -> List[str]: + input_names = [] + for input_param in self.loop_intermediates_inputs: + if input_param.required: + input_names.append(input_param.name) + return input_names + + # modified from SequentialPipelineBlocks to include loop_expected_components + @property + def expected_components(self): + expected_components = [] + for block in self.blocks.values(): + for component in block.expected_components: + if component not in expected_components: + expected_components.append(component) + for component in self.loop_expected_components: + if component not in expected_components: + expected_components.append(component) + return expected_components + + # modified from SequentialPipelineBlocks to include loop_expected_configs + @property + def expected_configs(self): + expected_configs = [] + for block in self.blocks.values(): + for config in block.expected_configs: + if config not in expected_configs: + expected_configs.append(config) + for config in self.loop_expected_configs: + if config not in expected_configs: + expected_configs.append(config) + return expected_configs + + # modified from SequentialPipelineBlocks to include loop_inputs + def get_inputs(self): + named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] + named_inputs.append(("loop", self.loop_inputs)) + combined_inputs = combine_inputs(*named_inputs) + # mark inputs as required only if that input is required by any of the blocks + for input_param in combined_inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return combined_inputs + + # Copied from SequentialPipelineBlocks + @property + def inputs(self): + return self.get_inputs() + + + # modified from SequentialPipelineBlocks to include loop_intermediates_inputs + @property + def intermediates_inputs(self): + intermediates = self.get_intermediates_inputs() + intermediate_names = [input.name for input in intermediates] + for loop_intermediate_input in self.loop_intermediates_inputs: + if loop_intermediate_input.name not in intermediate_names: + intermediates.append(loop_intermediate_input) + return intermediates + + + # Copied from SequentialPipelineBlocks + def get_intermediates_inputs(self): + inputs = [] + outputs = set() + + # Go through all blocks in order + for block in self.blocks.values(): + # Add inputs that aren't in outputs yet + inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) + + # Only add outputs if the block cannot be skipped + should_add_outputs = True + if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + should_add_outputs = False + + if should_add_outputs: + # Add this block's outputs + block_intermediates_outputs = [out.name for out in block.intermediates_outputs] + outputs.update(block_intermediates_outputs) + return inputs
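+ + # Sketch of a loop wrapper (illustrative): a subclass supplies the loop itself in __call__ and + # calls loop_step once per iteration; the timesteps intermediate and the t kwarg are + # assumptions, not part of the base class contract: + # + # class DenoiseLoop(LoopSequentialPipelineBlocks): + # @torch.no_grad() + # def __call__(self, components, state): + # block_state = self.get_block_state(state) + # for t in block_state.timesteps: + # components, block_state = self.loop_step(components, block_state, t=t) + # self.add_block_state(state, block_state) + # return components, state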
+ + + # modified from SequentialPipelineBlocks; any additional input required by the loop is also marked as required + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + required_by_loop = set(getattr(self, "loop_required_inputs", set())) + required_by_any.update(required_by_loop) + + # Union with required inputs from all other blocks + for block in list(self.blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + + # YiYi TODO: maybe we do not need this, it is only used in the docstring; + # intermediate_inputs is required by default, unless you manually handle it inside the block + @property + def required_intermediates_inputs(self) -> List[str]: + required_intermediates_inputs = [] + for input_param in self.intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + for input_param in self.loop_intermediates_inputs: + if input_param.required: + required_intermediates_inputs.append(input_param.name) + return required_intermediates_inputs + + + # YiYi TODO: this needs to be thought through more + # modified from SequentialPipelineBlocks to include loop_intermediates_outputs + @property + def intermediates_outputs(self) -> List[str]: + named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + combined_outputs = combine_outputs(*named_outputs) + combined_output_names = set(out.name for out in combined_outputs) + for output in self.loop_intermediates_outputs: + if output.name not in combined_output_names: + combined_outputs.append(output) + combined_output_names.add(output.name) + return combined_outputs + + # YiYi TODO: this needs to be thought through more + # copied from SequentialPipelineBlocks + @property + def outputs(self) -> List[str]: + return next(reversed(self.blocks.values())).intermediates_outputs + + + def __init__(self): + blocks = OrderedDict() + for block_name, block_cls in zip(self.block_names, self.block_classes): + blocks[block_name] = block_cls() + self.blocks = blocks + + @classmethod + def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": + """Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks. 
+ + Args: + blocks_dict: Dictionary mapping block names to block instances + + Returns: + A new LoopSequentialPipelineBlocks instance + """ + instance = cls() + instance.block_classes = [block.__class__ for block in blocks_dict.values()] + instance.block_names = list(blocks_dict.keys()) + instance.blocks = blocks_dict + return instance + + def loop_step(self, components, state: PipelineState, **kwargs): + + for block_name, block in self.blocks.items(): + try: + components, state = block(components, state, **kwargs) + except Exception as e: + error_msg = ( + f"\nError in block: ({block_name}, {block.__class__.__name__})\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + return components, state + + def __call__(self, components, state: PipelineState) -> PipelineState: + raise NotImplementedError("`__call__` method needs to be implemented by the subclass") + + + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + + # Check inputs + for input_param in self.inputs: + if input_param.name: + value = state.get_input(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + # Check intermediates + for input_param in self.intermediates_inputs: + if input_param.name: + value = state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required intermediate input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + elif input_param.kwargs_type: + # if kwargs_type is provided, get all intermediates with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + if intermediates_kwargs: + for k, v in intermediates_kwargs.items(): + if v is not None: + if k not in data: + data[k] = v + data[input_param.kwargs_type][k] = v + return BlockState(**data) + + def add_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediates_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.add_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediates_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs 
type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. + intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediates_kwargs.items(): + if not hasattr(block_state, param_name): + continue + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.add_intermediate(param_name, param, input_param.kwargs_type) + + + @property + def doc(self): + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) + + # modified from SequentialPipelineBlocks, + #(does not need trigger_inputs related part so removed them, + # do not need to support auto block for loop blocks) + def __repr__(self): + class_name = self.__class__.__name__ + base_class = self.__class__.__bases__[0].__name__ + header = ( + f"{class_name}(\n Class: {base_class}\n" + if base_class and base_class != "object" + else f"{class_name}(\n" + ) + + # Format description with proper indentation + desc_lines = self.description.split('\n') + desc = [] + # First line with "Description:" label + desc.append(f" Description: {desc_lines[0]}") + # Subsequent lines with proper indentation + if len(desc_lines) > 1: + desc.extend(f" {line}" for line in desc_lines[1:]) + desc = '\n'.join(desc) + '\n' + + # Components section - focus only on expected components + expected_components = getattr(self, "expected_components", []) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + + # Configs section - use format_configs with add_empty_lines=False + expected_configs = getattr(self, "expected_configs", []) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + + # Blocks section - moved to the end with simplified format + blocks_str = " Blocks:\n" + for i, (name, block) in enumerate(self.blocks.items()): + + # For SequentialPipelineBlocks, show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + + # Add block description + desc_lines = block.description.split('\n') + indented_desc = desc_lines[0] + if len(desc_lines) > 1: + indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + blocks_str += f" Description: {indented_desc}\n\n" + + # Build the representation with conditional sections + result = f"{header}\n{desc}" + + # Only add components section if it has content + if components_str.strip(): + result += f"\n\n{components_str}" + + # Only add configs section if it has content + if configs_str.strip(): + result += f"\n\n{configs_str}" + + # Always add blocks section + result += f"\n\n{blocks_str})" + + return result + + + + +# YiYi TODO: +# 1. move the modular_repo arg and the logic to fetch info from repo out of __init__ so that __init__ alwasy create an default modular_model_index config +# 2. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) +# 3. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader +# 4. 
add a validator for methods where we accept kwargs to be passed to from_pretrained() +class ModularLoader(ConfigMixin, PushToHubMixin): + """ + Base class for all Modular pipeline loaders. + """ + config_name = "modular_model_index.json" + + + def register_components(self, **kwargs): + """ + Register components with their corresponding specifications. + + This method is responsible for: + 1. Setting component objects as attributes on the loader (e.g., self.unet = unet) + 2. Updating the modular_model_index.json configuration for serialization + 3. Adding components to the component manager if one is attached + + This method is called when: + - Components are first initialized in __init__: + - from_pretrained components are not loaded during __init__, so they are registered as None; + - non from_pretrained components are created during __init__ and registered as the object itself + - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or loader.update(guider=guider_spec) + - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(component_names=["unet"]) + + Args: + **kwargs: Keyword arguments where keys are component names and values are component objects. + E.g., register_components(unet=unet_model, text_encoder=encoder_model) + + Notes: + - Components must be created from ComponentSpec (have a _diffusers_load_id attribute) + - When registering None for a component, it updates the modular_model_index.json config but sets the attribute to None + """ + for name, module in kwargs.items(): + # current component spec + component_spec = self._component_specs.get(name) + if component_spec is None: + logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") + continue + + # check whether the component has already been registered, i.e. whether we are being called from __init__ + is_registered = hasattr(self, name) + + # make sure the component is created from ComponentSpec + if module is not None and not hasattr(module, "_diffusers_load_id"): + raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") + + if module is not None: + # actual library and class name of the module + library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") + + # extract the loading spec from the updated component spec that'll be used as part of the modular_model_index.json config + # e.g. {"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), + # "subfolder": "unet", + # "variant": None, + # "revision": None} + component_spec_dict = self._component_spec_to_dict(component_spec) + + else: + # if module is None, e.g. 
self.register_components(unet=None) during __init__ + # we do not update the spec, + # but we still need to update the modular_model_index.json config based on the component spec + library, class_name = None, None + component_spec_dict = self._component_spec_to_dict(component_spec) + register_dict = {name: (library, class_name, component_spec_dict)} + + # set the component as an attribute + # if it is not set yet, just set it and skip the check-and-warn process below + if not is_registered: + self.register_to_config(**register_dict) + setattr(self, name, module) + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: + self._component_manager.add(name, module, self._collection) + continue + + current_module = getattr(self, name, None) + # skip if the component is already registered with the same object + if current_module is module: + logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") + continue + + # warn if unregistering + if current_module is not None and module is None: + logger.info( + f"ModularLoader.register_components: setting '{name}' to None " + f"(was {current_module.__class__.__name__})" + ) + # same type, new instance → replace but send debug log + elif current_module is not None \ + and module is not None \ + and isinstance(module, current_module.__class__) \ + and current_module != module: + logger.debug( + f"ModularLoader.register_components: replacing existing '{name}' " + f"(same type {type(current_module).__name__}, new instance)" + ) + + # update modular_model_index.json config + self.register_to_config(**register_dict) + # finally set the module + setattr(self, name, module) + # add to component manager if one is attached + if module is not None and module._diffusers_load_id != "null" and self._component_manager is not None: + self._component_manager.add(name, module, self._collection)
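+ + # After registration, each component occupies one entry of modular_model_index.json shaped like + # (illustrative values): + # + # "unet": ( + # "diffusers", "UNet2DConditionModel", # actual library / class of the registered object + # {"type_hint": ("diffusers", "UNet2DConditionModel"), + # "repo": "<repo>", "subfolder": "unet", "variant": None, "revision": None}, + # )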
+ """ + self._component_manager = component_manager + self._collection = collection + self._component_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) + } + self._config_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) + } + + # update component_specs and config_specs from modular_repo + if modular_repo is not None: + config_dict = self.load_config(modular_repo, **kwargs) + + for name, value in config_dict.items(): + # only update component_spec for from_pretrained components + if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = self._dict_to_component_spec(name, component_spec_dict) + self._component_specs[name] = component_spec + + elif name in self._config_specs: + self._config_specs[name].default = value + + register_components_dict = {} + for name, component_spec in self._component_specs.items(): + if component_spec.default_creation_method == "from_config": + component = component_spec.create() + else: + component = None + register_components_dict[name] = component + self.register_components(**register_components_dict) + + default_configs = {} + for name, config_spec in self._config_specs.items(): + default_configs[name] = config_spec.default + self.register_to_config(**default_configs) + + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @property + # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from + Accelerate's module hooks. + """ + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module): + continue + + if not hasattr(model, "_hf_hook"): + return self.device + for module in model.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + + modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @property + def dtype(self) -> torch.dtype: + r""" + Returns: + `torch.dtype`: The torch dtype on which the pipeline is located. + """ + modules = self.components.values() + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.dtype + + return torch.float32 + + + @property + def components(self) -> Dict[str, Any]: + # return only components we've actually set as attributes on self + return { + name: getattr(self, name) + for name in self._component_specs.keys() + if hasattr(self, name) + } + + def update(self, **kwargs): + """ + Update components and configs after instance creation. 
Update components and configuration values after the loader has been instantiated. + + This method allows you to: + 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) + 2. Update configuration values (e.g., changing the requires_safety_checker flag) + + Args: + **kwargs: Component objects or configuration values to update: + - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) + - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) + - ComponentSpec objects: only from_config specs are supported; the create() method will be called to build the component + + Raises: + ValueError: If a component wasn't created using ComponentSpec (doesn't have a `_diffusers_load_id` attribute) + + Examples: + ```python + # Update multiple components at once + loader.update( + unet=new_unet_model, + text_encoder=new_text_encoder + ) + + # Update configuration values + loader.update( + requires_safety_checker=False, + guidance_rescale=0.7 + ) + + # Update both components and configs together + loader.update( + unet=new_unet_model, + requires_safety_checker=False + ) + # update with ComponentSpec objects + loader.update( + guider=ComponentSpec(name="guider", type_hint=ClassifierFreeGuidance, config={"guidance_scale": 5.0}, default_creation_method="from_config") + ) + ``` + """ + + # extract component spec updates & config spec updates from `kwargs` + passed_component_specs = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)} + passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)} + passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} + + for name, component in passed_components.items(): + if not hasattr(component, "_diffusers_load_id"): + raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") + + # YiYi TODO: remove this if we remove support for non-ConfigMixin components in the `create()` method + if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): + raise ValueError( + f"The passed component '{name}' is not supported in the update() method " + "because it is not supported in `ComponentSpec.from_component()`. " + "Please pass a ComponentSpec object instead."
+ ) + current_component_spec = self._component_specs[name] + # warn if the type changed + if current_component_spec.type_hint is not None and not isinstance(component, current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the new component + new_component_spec = ComponentSpec.from_component(name, component) + self._component_specs[name] = new_component_spec + + if len(kwargs) > 0: + logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") + + created_components = {} + for name, component_spec in passed_component_specs.items(): + if component_spec.default_creation_method == "from_pretrained": + raise ValueError("ComponentSpec objects with default_creation_method == 'from_pretrained' are not supported in the update() method") + created_components[name] = component_spec.create() + current_component_spec = self._component_specs[name] + # warn if the type changed + if current_component_spec.type_hint is not None and not isinstance(created_components[name], current_component_spec.type_hint): + logger.warning(f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}") + # update _component_specs based on the user-passed component_spec + self._component_specs[name] = component_spec + self.register_components(**passed_components, **created_components) + + + config_to_register = {} + for name, new_value in passed_config_values.items(): + # e.g. requires_aesthetics_score = False + self._config_specs[name].default = new_value + config_to_register[name] = new_value + self.register_to_config(**config_to_register)
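+ + # The per-component kwargs convention in load() (illustrative): a dict value is resolved per + # component name, with "default" as the fallback: + # + # loader.load( + # component_names=["unet", "text_encoder"], + # torch_dtype={"unet": torch.bfloat16, "default": torch.float32}, + # ) + # # unet is loaded in bfloat16; text_encoder falls back to float32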
+ """ + # if not specific name, load all the components with default_creation_method == "from_pretrained" + if component_names is None: + component_names = [name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained"] + elif not isinstance(component_names, list): + component_names = [component_names] + + components_to_load = set([name for name in component_names if name in self._component_specs]) + unknown_component_names = set([name for name in component_names if name not in self._component_specs]) + if len(unknown_component_names) > 0: + logger.warning(f"Unknown components will be ignored: {unknown_component_names}") + + components_to_register = {} + for name in components_to_load: + spec = self._component_specs[name] + component_load_kwargs = {} + for key, value in kwargs.items(): + if not isinstance(value, dict): + # if the value is a single value, apply it to all components + component_load_kwargs[key] = value + else: + if name in value: + # if it is a dict, check if the component name is in the dict + component_load_kwargs[key] = value[name] + elif "default" in value: + # check if the default is specified + component_load_kwargs[key] = value["default"] + try: + components_to_register[name] = spec.load(**component_load_kwargs) + except Exception as e: + logger.warning(f"Failed to create component '{name}': {e}") + + # Register all components at once + self.register_components(**components_to_register) + + # YiYi TODO: should support to method + def to(self, *args, **kwargs): + pass + + # YiYi TODO: + # 1. should support save some components too! currently only modular_model_index.json is saved + # 2. maybe order the json file to make it more readable: configs first, then components + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): + + component_names = list(self._component_specs.keys()) + config_names = list(self._config_specs.keys()) + self.register_to_config(_components_names=component_names, _configs_names=config_names) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + config.pop("_components_names", None) + config.pop("_configs_names", None) + self._internal_dict = FrozenDict(config) + + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) + expected_component = set(config_dict.pop("_components_names")) + expected_config = set(config_dict.pop("_configs_names")) + + component_specs = [] + config_specs = [] + for name, value in config_dict.items(): + if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + # only pick up pretrained components from the repo + if component_spec_dict.get("repo", None) is not None: + component_spec = cls._dict_to_component_spec(name, component_spec_dict) + component_specs.append(component_spec) + + elif name in expected_config: + config_specs.append(ConfigSpec(name=name, default=value)) + + return cls(component_specs + config_specs, component_manager=component_manager, collection=collection) + + + @staticmethod + def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: + """ + 
Convert a ComponentSpec into a JSON‐serializable dict for saving in + `modular_model_index.json`. + + This dict contains: + - "type_hint": Tuple[str, str] + The fully‐qualified module path and class name of the component. + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. + + Args: + component_spec (ComponentSpec): + The spec object describing one pipeline component. + + Returns: + Dict[str, Any]: A mapping suitable for JSON serialization. + + Example: + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec + >>> from diffusers.models.unet import UNet2DConditionModel + >>> spec = ComponentSpec( + ... name="unet", + ... type_hint=UNet2DConditionModel, + ... config=None, + ... repo="path/to/repo", + ... subfolder="subfolder", + ... variant=None, + ... revision=None, + ... default_creation_method="from_pretrained", + ... ) + >>> ModularLoader._component_spec_to_dict(spec) + { + "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": None, + "revision": None, + } + """ + if component_spec.type_hint is not None: + lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) + else: + lib_name = None + cls_name = None + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + return { + "type_hint": (lib_name, cls_name), + **load_spec_dict, + } + + @staticmethod + def _dict_to_component_spec( + name: str, + spec_dict: Dict[str, Any], + ) -> ComponentSpec: + """ + Reconstruct a ComponentSpec from a dict. + """ + # make a shallow copy so we can pop() safely + spec_dict = spec_dict.copy() + # pull out and resolve the stored type_hint + lib_name, cls_name = spec_dict.pop("type_hint") + if lib_name is not None and cls_name is not None: + type_hint = simple_get_class_obj(lib_name, cls_name) + else: + type_hint = None + + # re‐assemble the ComponentSpec + return ComponentSpec( + name=name, + type_hint=type_hint, + **spec_dict, + ) \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py new file mode 100644 index 000000000000..6d6704f4eb38 --- /dev/null +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -0,0 +1,595 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import re
+import inspect
+from dataclasses import dataclass, asdict, field, fields
+from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal
+
+from ..utils.import_utils import is_torch_available
+from ..configuration_utils import FrozenDict, ConfigMixin
+
+if is_torch_available():
+    import torch
+
+
+# YiYi TODO:
+# 1. validate the dataclass fields
+# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained()
+@dataclass
+class ComponentSpec:
+    """Specification for a pipeline component.
+
+    A component can be created in two ways:
+    1. from scratch, using `__init__` with a config dict
+    2. from pretrained weights, using `from_pretrained`
+
+    Attributes:
+        name: Name of the component
+        type_hint: Type of the component (e.g. UNet2DConditionModel)
+        description: Optional description of the component
+        config: Optional config dict for __init__ creation
+        repo: Optional repo path for from_pretrained creation
+        subfolder: Optional subfolder in repo
+        variant: Optional variant in repo
+        revision: Optional revision in repo
+        default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
+    """
+    name: Optional[str] = None
+    type_hint: Optional[Type] = None
+    description: Optional[str] = None
+    config: Optional[FrozenDict[str, Any]] = None
+    # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
+    repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
+    subfolder: Optional[str] = field(default=None, metadata={"loading": True})
+    variant: Optional[str] = field(default=None, metadata={"loading": True})
+    revision: Optional[str] = field(default=None, metadata={"loading": True})
+    default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
+
+    def __hash__(self):
+        """Make ComponentSpec hashable, keyed by (name, load_id, default_creation_method)."""
+        return hash((self.name, self.load_id, self.default_creation_method))
+
+    def __eq__(self, other):
+        """Compare ComponentSpec objects based on name, load_id and default_creation_method."""
+        if not isinstance(other, ComponentSpec):
+            return False
+        return (self.name == other.name and
+                self.load_id == other.load_id and
+                self.default_creation_method == other.default_creation_method)
+
+    @classmethod
+    def from_component(cls, name: str, component: Any) -> "ComponentSpec":
+        """Create a ComponentSpec from a component created by the `create` or `load` method."""
+
+        if not hasattr(component, "_diffusers_load_id"):
+            raise ValueError("Component was not created by the `create` or `load` method")
+        # throw an error if the component was created with the `create` method but is not a subclass of ConfigMixin
+        # YiYi TODO: remove this check if we remove support for non-ConfigMixin in the `create()` method
+        if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
+            raise ValueError(
+                "We currently only support creating a ComponentSpec from a component that was "
+                "created with the `ComponentSpec.load` method, "
+                "or created with `ComponentSpec.create` and is a subclass of ConfigMixin"
+            )
+
+        type_hint = component.__class__
+        default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained"
+
+        if isinstance(component, ConfigMixin):
+            config = component.config
+        else:
+            config = None
+
+        load_spec = cls.decode_load_id(component._diffusers_load_id)
+
+        return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec)
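+
+    # Encoding sketch (illustrative values): `load_id` serializes the loading fields
+    # in declaration order, with None encoded as "null", and `decode_load_id` reverses it:
+    #
+    #   spec = ComponentSpec(name="unet", repo="repo/id", subfolder="unet")
+    #   spec.load_id
+    #   # -> "repo/id|unet|null|null"
+    #   ComponentSpec.decode_load_id(spec.load_id)
+    #   # -> {"repo": "repo/id", "subfolder": "unet", "variant": None, "revision": None}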
+
+    @classmethod
+    def loading_fields(cls) -> List[str]:
+        """
+        Return the names of all loading-related fields
+        (i.e. those whose field.metadata["loading"] is True).
+        """
+        return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
+
+    @property
+    def load_id(self) -> str:
+        """
+        Unique identifier for this spec's pretrained load,
+        composed of repo|subfolder|variant|revision (no empty segments).
+        """
+        parts = [getattr(self, k) for k in self.loading_fields()]
+        parts = ["null" if p is None else p for p in parts]
+        return "|".join(p for p in parts if p)
+
+    @classmethod
+    def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
+        """
+        Decode a load_id string back into a dictionary of loading fields and values.
+
+        Args:
+            load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
+                where None values are represented as "null"
+
+        Returns:
+            Dict mapping loading field names to their values, e.g.
+            {
+                "repo": "path/to/repo",
+                "subfolder": "subfolder",
+                "variant": "variant",
+                "revision": "revision"
+            }
+            If a segment value is "null", it is replaced with None. If `load_id` itself is
+            "null" (the component was not created with the `load` method), a dict with all
+            values set to None is returned.
+        """
+
+        # Get all loading fields in order
+        loading_fields = cls.loading_fields()
+        result = {f: None for f in loading_fields}
+
+        if load_id == "null":
+            return result
+
+        # Split the load_id
+        parts = load_id.split("|")
+
+        # Map parts to loading fields by position
+        for i, part in enumerate(parts):
+            if i < len(loading_fields):
+                # Convert "null" string back to None
+                result[loading_fields[i]] = None if part == "null" else part
+
+        return result
+
+    # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guiders and image processors config mixins),
+    # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component):
+    # the config info is lost in the process.
+    # remove the error check in `from_component` and ModularLoader.update() if we remove support for non-ConfigMixin in the `create()` method
+    def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
+        """Create the component from a config, via `from_config` for ConfigMixin subclasses or `__init__` otherwise."""
+
+        if self.type_hint is None or not isinstance(self.type_hint, type):
+            raise ValueError(
+                "`type_hint` is required when using the from_config creation method."
+            )
+
+        config = config or self.config or {}
+
+        if issubclass(self.type_hint, ConfigMixin):
+            component = self.type_hint.from_config(config, **kwargs)
+        else:
+            signature_params = inspect.signature(self.type_hint.__init__).parameters
+            init_kwargs = {}
+            for k, v in config.items():
+                if k in signature_params:
+                    init_kwargs[k] = v
+            for k, v in kwargs.items():
+                if k in signature_params:
+                    init_kwargs[k] = v
+            component = self.type_hint(**init_kwargs)
+
+        component._diffusers_load_id = "null"
+        if hasattr(component, "config"):
+            self.config = component.config
+
+        return component
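+
+    # Sketch (illustrative): `create()` builds the component from a config instead of
+    # downloading weights; the scheduler class and config key below are example
+    # assumptions, not requirements:
+    #
+    #   from diffusers import EulerDiscreteScheduler
+    #   spec = ComponentSpec(
+    #       name="scheduler",
+    #       type_hint=EulerDiscreteScheduler,
+    #       config={"timestep_spacing": "trailing"},
+    #       default_creation_method="from_config",
+    #   )
+    #   scheduler = spec.create()      # EulerDiscreteScheduler via from_config
+    #   scheduler._diffusers_load_id   # "null"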
+
+    # YiYi TODO: add guard for type of model, if it is supported by from_pretrained
+    def load(self, **kwargs) -> Any:
+        """Load the component using `from_pretrained`."""
+
+        # select loading fields from kwargs passed by the user, e.g. repo, subfolder, variant, revision; note the list could change
+        passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
+        # merge the loading field values in the spec with user-passed values to create load_kwargs
+        load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
+        # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
+        repo = load_kwargs.pop("repo", None)
+        if repo is None:
+            raise ValueError("`repo` info is required when using the `load` method (you can set it directly in the `repo` field of the ComponentSpec or pass it as an argument)")
+
+        if self.type_hint is None:
+            try:
+                from diffusers import AutoModel
+                component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
+            except Exception as e:
+                raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
+            # update type_hint if AutoModel loads successfully
+            self.type_hint = component.__class__
+        else:
+            try:
+                component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
+            except Exception as e:
+                raise ValueError(f"Unable to load {self.name} using the load method: {e}")
+
+        self.repo = repo
+        for k, v in load_kwargs.items():
+            setattr(self, k, v)
+        component._diffusers_load_id = self.load_id
+
+        return component
+
+
+@dataclass
+class ConfigSpec:
+    """Specification for a pipeline configuration parameter."""
+    name: str
+    default: Any
+    description: Optional[str] = None
+
+
+# YiYi Notes: both inputs and intermediates_inputs are InputParam objects,
+# but some fields are not relevant for intermediates_inputs:
+# e.g. unlike for inputs, `required` is only used in the docstring for intermediates_inputs; we do not check whether a required intermediate input is passed
+# `default` is not used for intermediates_inputs; we only use the default from inputs, so it is ignored if set on an intermediates_input
+# -> should we use a different class for inputs and intermediates_inputs?
+@dataclass
+class InputParam:
+    """Specification for an input parameter."""
+    name: Optional[str] = None
+    type_hint: Any = None
+    default: Any = None
+    required: bool = False
+    description: str = ""
+    kwargs_type: Optional[str] = None  # YiYi Notes: remove this feature (maybe)
+
+    def __repr__(self):
+        return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
+
+
+@dataclass
+class OutputParam:
+    """Specification for an output parameter."""
+    name: str
+    type_hint: Any = None
+    description: str = ""
+    kwargs_type: Optional[str] = None  # YiYi notes: remove this feature (maybe)
+
+    def __repr__(self):
+        return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
+
+
+def format_inputs_short(inputs):
+    """
+    Format input parameters into a string representation, with required params first followed by optional ones.
+
+    Args:
+        inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params
+
+    Returns:
+        str: Formatted string of input parameters
+
+    Example:
+        >>> inputs = [
+        ...     InputParam(name="prompt", required=True),
+        ...     InputParam(name="image", required=True),
+        ...     InputParam(name="guidance_scale", required=False, default=7.5),
+        ...     InputParam(name="num_inference_steps", required=False, default=50)
+        ...
] + >>> format_inputs_short(inputs) + 'prompt, image, guidance_scale=7.5, num_inference_steps=50' + """ + required_inputs = [param for param in inputs if param.required] + optional_inputs = [param for param in inputs if not param.required] + + required_str = ", ".join(param.name for param in required_inputs) + optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) + + inputs_str = required_str + if optional_str: + inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str + + return inputs_str + + +def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): + """ + Formats intermediate inputs and outputs of a block into a string representation. + + Args: + intermediates_inputs: List of intermediate input parameters + required_intermediates_inputs: List of required intermediate input names + intermediates_outputs: List of intermediate output parameters + + Returns: + str: Formatted string like: + Intermediates: + - inputs: Required(latents), dtype + - modified: latents # variables that appear in both inputs and outputs + - outputs: images # new outputs only + """ + # Handle inputs + input_parts = [] + for inp in intermediates_inputs: + if inp.name in required_intermediates_inputs: + input_parts.append(f"Required({inp.name})") + else: + if inp.name is None and inp.kwargs_type is not None: + inp_name = "*_" + inp.kwargs_type + else: + inp_name = inp.name + input_parts.append(inp_name) + + # Handle modified variables (appear in both inputs and outputs) + inputs_set = {inp.name for inp in intermediates_inputs} + modified_parts = [] + new_output_parts = [] + + for out in intermediates_outputs: + if out.name in inputs_set: + modified_parts.append(out.name) + else: + new_output_parts.append(out.name) + + result = [] + if input_parts: + result.append(f" - inputs: {', '.join(input_parts)}") + if modified_parts: + result.append(f" - modified: {', '.join(modified_parts)}") + if new_output_parts: + result.append(f" - outputs: {', '.join(new_output_parts)}") + + return "\n".join(result) if result else " (none)" + + +def format_params(params, header="Args", indent_level=4, max_line_length=115): + """Format a list of InputParam or OutputParam objects into a readable string representation. + + Args: + params: List of InputParam or OutputParam objects to format + header: Header text to use (e.g. 
"Args" or "Returns") + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all parameters + """ + if not params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text, indent, max_length): + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + word_length = len(word) + (1 if current_line else 0) + + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + return f"\n{indent}".join(lines) + + # Add the header + formatted_params.append(f"{base_indent}{header}:") + + for param in params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + param_str = f"{param_indent}{param.name} (`{type_str}`" + + # Add optional tag and default value if parameter is an InputParam and optional + if hasattr(param, "required"): + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" + param_str += "):" + + # Add description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub( + r'\[(.*?)\]\((https?://[^\s\)]+)\)', + r'[\1](\2)', + param.description + ) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) + + +def format_input_params(input_params, indent_level=4, max_line_length=115): + """Format a list of InputParam objects into a readable string representation. + + Args: + input_params: List of InputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all input parameters + """ + return format_params(input_params, "Inputs", indent_level, max_line_length) + + +def format_output_params(output_params, indent_level=4, max_line_length=115): + """Format a list of OutputParam objects into a readable string representation. + + Args: + output_params: List of OutputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all output parameters + """ + return format_params(output_params, "Outputs", indent_level, max_line_length) + + +def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ComponentSpec objects into a readable string representation. 
+ + Args: + components: List of ComponentSpec objects to format + indent_level: Number of spaces to indent each component line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between components (default: True) + + Returns: + A formatted string representing all components + """ + if not components: + return "" + + base_indent = " " * indent_level + component_indent = " " * (indent_level + 4) + formatted_components = [] + + # Add the header + formatted_components.append(f"{base_indent}Components:") + if add_empty_lines: + formatted_components.append("") + + # Add each component with optional empty lines between them + for i, component in enumerate(components): + # Get type name, handling special cases + type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + + component_desc = f"{component_indent}{component.name} (`{type_name}`)" + if component.description: + component_desc += f": {component.description}" + + # Get the loading fields dynamically + loading_field_values = [] + for field_name in component.loading_fields(): + field_value = getattr(component, field_name) + if field_value is not None: + loading_field_values.append(f"{field_name}={field_value}") + + # Add loading field information if available + if loading_field_values: + component_desc += f" [{', '.join(loading_field_values)}]" + + formatted_components.append(component_desc) + + # Add an empty line after each component except the last one + if add_empty_lines and i < len(components) - 1: + formatted_components.append("") + + return "\n".join(formatted_components) + + +def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ConfigSpec objects into a readable string representation. + + Args: + configs: List of ConfigSpec objects to format + indent_level: Number of spaces to indent each config line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between configs (default: True) + + Returns: + A formatted string representing all configs + """ + if not configs: + return "" + + base_indent = " " * indent_level + config_indent = " " * (indent_level + 4) + formatted_configs = [] + + # Add the header + formatted_configs.append(f"{base_indent}Configs:") + if add_empty_lines: + formatted_configs.append("") + + # Add each config with optional empty lines between them + for i, config in enumerate(configs): + config_desc = f"{config_indent}{config.name} (default: {config.default})" + if config.description: + config_desc += f": {config.description}" + formatted_configs.append(config_desc) + + # Add an empty line after each config except the last one + if add_empty_lines and i < len(configs) - 1: + formatted_configs.append("") + + return "\n".join(formatted_configs) + + +def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. 
+ + Args: + inputs: List of input parameters + intermediates_inputs: List of intermediate input parameters + outputs: List of output parameters + description (str, *optional*): Description of the block + class_name (str, *optional*): Name of the class to include in the documentation + expected_components (List[ComponentSpec], *optional*): List of expected components + expected_configs (List[ConfigSpec], *optional*): List of expected configurations + + Returns: + str: A formatted string containing information about components, configs, call parameters, + intermediate inputs/outputs, and final outputs. + """ + output = "" + + # Add class name if provided + if class_name: + output += f"class {class_name}\n\n" + + # Add description + if description: + desc_lines = description.strip().split('\n') + aligned_desc = '\n'.join(' ' + line for line in desc_lines) + output += aligned_desc + "\n\n" + + # Add components section if provided + if expected_components and len(expected_components) > 0: + components_str = format_components(expected_components, indent_level=2) + output += components_str + "\n\n" + + # Add configs section if provided + if expected_configs and len(expected_configs) > 0: + configs_str = format_configs(expected_configs, indent_level=2) + output += configs_str + "\n\n" + + # Add inputs section + output += format_input_params(inputs + intermediates_inputs, indent_level=2) + + # Add outputs section + output += "\n\n" + output += format_output_params(outputs, indent_level=2) + + return output \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py new file mode 100644 index 000000000000..9ee9c069277d --- /dev/null +++ b/src/diffusers/modular_pipelines/node_utils.py @@ -0,0 +1,519 @@ +from ..configuration_utils import ConfigMixin +from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineMixin +from .modular_pipeline_utils import InputParam, OutputParam +from ..image_processor import PipelineImageInput +from pathlib import Path +import json +import os + +from typing import Union, List, Optional, Tuple +import torch +import PIL +import numpy as np +import logging +logger = logging.getLogger(__name__) + +# YiYi Notes: this is actually for SDXL, put it here for now +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), + "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), + "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), + "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), + "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), + "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), + "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), + "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), + "generator": InputParam("generator", 
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), + "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), + "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), + "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"), + "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), + "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), + "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), + "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), + "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), + "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), + "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), + "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), + "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), + "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), + "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), + "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), + "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), + "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When 
ControlNet starts applying"), + "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), + "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), + "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), + "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") +} + +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { + "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), + "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), + "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), + "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), + "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") +} + +SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, 
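+
+# Illustrative note: the merged schema is keyed by parameter name; for names that
+# appear in both dicts (e.g. "latents", "num_inference_steps") the intermediate
+# entry wins because it is unpacked last. For example:
+#
+#   SDXL_PARAM_SCHEMA["strength"].default        # 0.3
+#   SDXL_PARAM_SCHEMA["prompt_embeds"].required  # True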
+
+
+DEFAULT_PARAM_MAPS = {
+    "prompt": {
+        "label": "Prompt",
+        "type": "string",
+        "default": "a bear sitting in a chair drinking a milkshake",
+        "display": "textarea",
+    },
+    "negative_prompt": {
+        "label": "Negative Prompt",
+        "type": "string",
+        "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
+        "display": "textarea",
+    },
+    "num_inference_steps": {
+        "label": "Steps",
+        "type": "int",
+        "default": 25,
+        "min": 1,
+        "max": 1000,
+    },
+    "seed": {
+        "label": "Seed",
+        "type": "int",
+        "default": 0,
+        "min": 0,
+        "display": "random",
+    },
+    "width": {
+        "label": "Width",
+        "type": "int",
+        "display": "text",
+        "default": 1024,
+        "min": 8,
+        "max": 8192,
+        "step": 8,
+        "group": "dimensions",
+    },
+    "height": {
+        "label": "Height",
+        "type": "int",
+        "display": "text",
+        "default": 1024,
+        "min": 8,
+        "max": 8192,
+        "step": 8,
+        "group": "dimensions",
+    },
+    "images": {
+        "label": "Images",
+        "type": "image",
+        "display": "output",
+    },
+    "image": {
+        "label": "Image",
+        "type": "image",
+        "display": "input",
+    },
+}
+
+DEFAULT_TYPE_MAPS = {
+    "int": {
+        "type": "int",
+        "default": 0,
+        "min": 0,
+    },
+    "float": {
+        "type": "float",
+        "default": 0.0,
+        "min": 0.0,
+    },
+    "str": {
+        "type": "string",
+        "default": "",
+    },
+    "bool": {
+        "type": "boolean",
+        "default": False,
+    },
+    "image": {
+        "type": "image",
+    },
+}
+
+DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
+DEFAULT_CATEGORY = "Modular Diffusers"
+DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
+DEFAULT_PARAMS_GROUPS_KEYS = {
+    "text_encoders": ["text_encoder", "tokenizer"],
+    "ip_adapter_embeds": ["ip_adapter_embeds"],
+    "prompt_embeddings": ["prompt_embeds"],
+}
+
+
+def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
+    """
+    Get the group name for a given parameter name; if it is not part of a group, return None.
+    e.g. "prompt_embeds" -> "prompt_embeddings", "text_encoder" -> "text_encoders", "prompt" -> None
+    """
+    if name is None:
+        return None
+    for group_name, group_keys in group_params_keys.items():
+        for group_key in group_keys:
+            if group_key in name:
+                return group_name
+    return None
+
+
+class ModularNode(ConfigMixin):
+
+    config_name = "node_config.json"
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str,
+        trust_remote_code: Optional[bool] = None,
+        **kwargs,
+    ):
+        blocks = ModularPipelineMixin.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
+        return cls(blocks, **kwargs)
+
+    def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
+        self.blocks = blocks
+
+        if label is None:
+            label = self.blocks.__class__.__name__
+        # blocks param name -> mellon param name
+        self.name_mapping = {}
+
+        input_params = {}
+        # pass or create a default param dict for each input
+        # e.g. for prompt:
+        # prompt = {
+        #     "name": "text_input",  # the name of the input in the node definition, can differ from the input name in diffusers
+        #     "label": "Prompt",
+        #     "type": "string",
+        #     "default": "a bear sitting in a chair drinking a milkshake",
+        #     "display": "textarea"}
+        # if type is not specified, it'll be a "custom" param of its own type,
+        # e.g. you can pass ModularNode(scheduler={"name": "scheduler"});
+        # it will get this spec in the node definition: {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
+        # name can be a dict; in that case it is part of a "dict" input in mellon nodes, e.g. ModularNode(text_encoder={"name": {"text_encoders": "text_encoder"}})
+        inputs = self.blocks.inputs + self.blocks.intermediates_inputs
+        for inp in inputs:
+            param = kwargs.pop(inp.name, None)
+            if param:
+                # user can pass a param dict for any input, e.g. ModularNode(prompt={...})
+                input_params[inp.name] = param
+                mellon_name = param.pop("name", inp.name)
+                if mellon_name != inp.name:
+                    self.name_mapping[inp.name] = mellon_name
+                continue
+
+            if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
+                continue
+
+            if inp.name in DEFAULT_PARAM_MAPS:
+                # first check if it's in the default param map; if so, directly use that
+                param = DEFAULT_PARAM_MAPS[inp.name].copy()
+            elif get_group_name(inp.name):
+                param = get_group_name(inp.name)
+                if inp.name not in self.name_mapping:
+                    self.name_mapping[inp.name] = param
+            else:
+                # if not, check if it's in the SDXL input schema; if so,
+                # 1. use the type hint to determine the type
+                # 2. use the default param dict for the type, e.g. if "steps" is an "int" type: {"steps": {"type": "int", "default": 0, "min": 0}}
+                if inp.type_hint is not None:
+                    type_str = str(inp.type_hint).lower()
+                else:
+                    inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
+                    type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
+                for type_key, type_param in DEFAULT_TYPE_MAPS.items():
+                    if type_key in type_str:
+                        param = type_param.copy()
+                        param["label"] = inp.name
+                        param["display"] = "input"
+                        break
+                else:
+                    param = inp.name
+            # add the param dict to the input_params dict
+            input_params[inp.name] = param
+
+        component_params = {}
+        for comp in self.blocks.expected_components:
+            param = kwargs.pop(comp.name, None)
+            if param:
+                component_params[comp.name] = param
+                mellon_name = param.pop("name", comp.name)
+                if mellon_name != comp.name:
+                    self.name_mapping[comp.name] = mellon_name
+                continue
+
+            to_exclude = False
+            for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
+                if exclude_key in comp.name:
+                    to_exclude = True
+                    break
+            if to_exclude:
+                continue
+
+            if get_group_name(comp.name):
+                param = get_group_name(comp.name)
+                if comp.name not in self.name_mapping:
+                    self.name_mapping[comp.name] = param
+            elif comp.name in DEFAULT_MODEL_KEYS:
+                param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
+            else:
+                param = comp.name
+            # add the param dict to the component_params dict
+            component_params[comp.name] = param
+
+        output_params = {}
+        if isinstance(self.blocks, SequentialPipelineBlocks):
+            last_block_name = list(self.blocks.blocks.keys())[-1]
+            outputs = self.blocks.blocks[last_block_name].intermediates_outputs
+        else:
+            outputs = self.blocks.intermediates_outputs
+
+        for out in outputs:
+            param = kwargs.pop(out.name, None)
+            if param:
+                output_params[out.name] = param
+                mellon_name = param.pop("name", out.name)
+                if mellon_name != out.name:
+                    self.name_mapping[out.name] = mellon_name
+                continue
+
+            if out.name in DEFAULT_PARAM_MAPS:
+                param = DEFAULT_PARAM_MAPS[out.name].copy()
+                param["display"] = "output"
+            else:
+                group_name = get_group_name(out.name)
+                if group_name:
+                    param = group_name
+                    if out.name not in self.name_mapping:
+                        self.name_mapping[out.name] = param
+                else:
+                    param = out.name
+            # add the param dict to the output_params dict
+            output_params[out.name] =
param + + if len(kwargs) > 0: + logger.warning(f"Unused kwargs: {kwargs}") + + register_dict = { + "category": category, + "label": label, + "input_params": input_params, + "component_params": component_params, + "output_params": output_params, + "name_mapping": self.name_mapping, + } + self.register_to_config(**register_dict) + + def setup(self, components, collection=None): + self.blocks.setup_loader(component_manager=components, collection=collection) + self._components_manager = components + + @property + def mellon_config(self): + return self._convert_to_mellon_config() + + def _convert_to_mellon_config(self): + + node = {} + node["label"] = self.config.label + node["category"] = self.config.category + + node_param = {} + for inp_name, inp_param in self.config.input_params.items(): + if inp_name in self.name_mapping: + mellon_name = self.name_mapping[inp_name] + else: + mellon_name = inp_name + if isinstance(inp_param, str): + param = { + "label": inp_param, + "type": inp_param, + "display": "input", + } + else: + param = inp_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}") + + + for comp_name, comp_param in self.config.component_params.items(): + if comp_name in self.name_mapping: + mellon_name = self.name_mapping[comp_name] + else: + mellon_name = comp_name + if isinstance(comp_param, str): + param = { + "label": comp_param, + "type": comp_param, + "display": "input", + } + else: + param = comp_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}") + + + for out_name, out_param in self.config.output_params.items(): + if out_name in self.name_mapping: + mellon_name = self.name_mapping[out_name] + else: + mellon_name = out_name + if isinstance(out_param, str): + param = { + "label": out_param, + "type": out_param, + "display": "output", + } + else: + param = out_param + + if mellon_name not in node_param: + node_param[mellon_name] = param + else: + logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}") + node["params"] = node_param + return node + + def save_mellon_config(self, file_path): + """ + Save the Mellon configuration to a JSON file. + + Args: + file_path (str or Path): Path where the JSON file will be saved + + Returns: + Path: Path to the saved config file + """ + file_path = Path(file_path) + + # Create directory if it doesn't exist + os.makedirs(file_path.parent, exist_ok=True) + + # Create a combined dictionary with module definition and name mapping + config = { + "module": self.mellon_config, + "name_mapping": self.name_mapping + } + + # Save the config to file + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(config, f, indent=2) + + logger.info(f"Mellon config and name mapping saved to {file_path}") + + return file_path + + @classmethod + def load_mellon_config(cls, file_path): + """ + Load a Mellon configuration from a JSON file. 
+ + Args: + file_path (str or Path): Path to the JSON file containing Mellon config + + Returns: + dict: The loaded combined configuration containing 'module' and 'name_mapping' + """ + file_path = Path(file_path) + + if not file_path.exists(): + raise FileNotFoundError(f"Config file not found: {file_path}") + + with open(file_path, 'r', encoding='utf-8') as f: + config = json.load(f) + + logger.info(f"Mellon config loaded from {file_path}") + + + return config + + def process_inputs(self, **kwargs): + + params_components = {} + for comp_name, comp_param in self.config.component_params.items(): + logger.debug(f"component: {comp_name}") + mellon_comp_name = self.name_mapping.get(comp_name, comp_name) + if mellon_comp_name in kwargs: + if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]: + comp = kwargs[mellon_comp_name].pop(comp_name) + else: + comp = kwargs.pop(mellon_comp_name) + if comp: + params_components[comp_name] = self._components_manager.get_one(comp["model_id"]) + + + params_run = {} + for inp_name, inp_param in self.config.input_params.items(): + logger.debug(f"input: {inp_name}") + mellon_inp_name = self.name_mapping.get(inp_name, inp_name) + if mellon_inp_name in kwargs: + if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]: + inp = kwargs[mellon_inp_name].pop(inp_name) + else: + inp = kwargs.pop(mellon_inp_name) + if inp is not None: + params_run[inp_name] = inp + + return_output_names = list(self.config.output_params.keys()) + + return params_components, params_run, return_output_names + + def execute(self, **kwargs): + params_components, params_run, return_output_names = self.process_inputs(**kwargs) + + self.blocks.loader.update(**params_components) + output = self.blocks.run(**params_run, output=return_output_names) + return output + + + + + + + + + + + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 000000000000..f3f961d61a13 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"] + _import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"] + _import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"] + _import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_pipeline_presets import StableDiffusionXLAutoPipeline + from .modular_loader import StableDiffusionXLModularLoader + from 
.encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep + from .decoders import StableDiffusionXLAutoDecodeStep +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py new file mode 100644 index 000000000000..07f096249c0d --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -0,0 +1,1764 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...models import ControlNetModel, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...utils import logging +from ...utils.torch_utils import randn_tensor, unwrap_module + +from ...pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...schedulers import EulerDiscreteScheduler +from ...configuration_utils import FrozenDict + +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline import ( + AutoPipelineBlocks, + ModularLoader, + PipelineBlock, + PipelineState, + SequentialPipelineBlocks, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_latents_img2img(vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True): + + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if vae.config.force_upcast: + image = image.float() + vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(vae.encode(image), generator=generator) + + if vae.config.force_upcast: + vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std + else: + init_latents = vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + +class StableDiffusionXLInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), + InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), + InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. 
Can be generated from ip_adapter step."), + InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="image embeddings for IP-Adapter"), + OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], kwargs_type="guider_input_fields", description="negative image embeddings for IP-Adapter"), + ] + + def check_inputs(self, components, block_state): + + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." + ) + + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
+ ) + + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): + raise ValueError("`ip_adapter_embeds` must be a list") + + if block_state.negative_ip_adapter_embeds is not None and not isinstance(block_state.negative_ip_adapter_embeds, list): + raise ValueError("`negative_ip_adapter_embeds` must be a list") + + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: + raise ValueError( + "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" + f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" + f" {block_state.negative_ip_adapter_embeds[i].shape}." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.negative_pooled_prompt_embeds is not None: + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + block_state.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + + if block_state.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): + block_state.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the timesteps for the scheduler and determines the 
initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + InputParam("strength", default=0.3), + InputParam("denoising_start"), + # YiYi TODO: do we need num_images_per_prompt here? + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), + OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") + ] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self -> components + def get_timesteps(self, components, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (denoising_start * components.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if components.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1 + # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(components.scheduler.timesteps) - num_inference_steps + timesteps = components.scheduler.timesteps[t_start:] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, block_state.num_inference_steps, block_state.device, block_state.timesteps, block_state.sigmas + ) + + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( + components, + block_state.num_inference_steps, + block_state.strength, + block_state.device, + denoising_start=block_state.denoising_start if denoising_value_valid(block_state.denoising_start) else None, + ) + block_state.latent_timestep = block_state.timesteps[:1].repeat(block_state.batch_size * block_state.num_images_per_prompt) + + if block_state.denoising_end is not None and isinstance(block_state.denoising_end, float) and block_state.denoising_end > 0 and block_state.denoising_end < 1:
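+ # For example (hypothetical values): with num_train_timesteps=1000 and denoising_end=0.8,
+ # discrete_timestep_cutoff = round(1000 - 0.8 * 1000) = 200, so only timesteps >= 200 are kept
+ # and the denoising loop runs over the first 80% of the schedule.
+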
block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len(list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))) + block_state.timesteps = block_state.timesteps[:block_state.num_inference_steps] + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the inpainting process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam( + "strength", + default=0.9999, + description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " + "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " + "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " + "be maximum and the denoising process will run for the full number of iterations specified in " + "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " + "`denoising_start` being declared as an integer, the value of `strength` will be ignored." + ), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask for the inpainting generation. Can be generated in vae_encode step." + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." 
+ ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] + + @property + def intermediates_outputs(self) -> List[str]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), + OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] + + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use # Copied from + @staticmethod + def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): + + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument + def prepare_latents_inpaint( + self, + components, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1, initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + )
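+ # For example (hypothetical values): strength=1.0 gives is_strength_max=True, so the initial
+ # latents below start as pure noise scaled by scheduler.init_noise_sigma; with strength=0.7 the
+ # encoded image latents are instead noised to the intermediate latent_timestep via scheduler.add_noise.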
+ "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(components, image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * components.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concatenating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + block_state.is_strength_max = block_state.strength == 1.0 + + # for non-inpainting specific unet, we do not need masked_image_latents + if hasattr(components, "unet") and components.unet is not None: + if components.unet.config.in_channels == 4: + block_state.masked_image_latents = None + + block_state.add_noise = True if block_state.denoising_start is None else False + + block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor + block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + + block_state.latents, block_state.noise = self.prepare_latents_inpaint( + components, + block_state.batch_size * block_state.num_images_per_prompt, + components.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + image=block_state.image_latents, + timestep=block_state.latent_timestep, + is_strength_max=block_state.is_strength_max, + add_noise=block_state.add_noise, + return_noise=True, + return_image_latents=False, + ) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image_latents, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the latents for the image-to-image generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), + InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt.
Can be generated in input step."), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + block_state.add_noise = True if block_state.denoising_start is None else False + if block_state.latents is None: + block_state.latents = prepare_latents_img2img( + components.vae, + components.scheduler, + block_state.image_latents, + block_state.latent_timestep, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLPrepareLatentsStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that prepares the initial latents for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the model inputs" + ) + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process" + ) + ] + + + @staticmethod + def check_inputs(components, block_state): + if ( + block_state.height is not None + and block_state.height % components.vae_scale_factor != 0 + or block_state.width is not None + and block_state.width % components.vae_scale_factor != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self -> components + @staticmethod + def prepare_latents(components, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + )
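+ # For example (hypothetical values): batch_size=2, num_channels_latents=4, height=width=1024 and
+ # vae_scale_factor=8 give a latent shape of (2, 4, 128, 128); a generator list must then have length 2.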
+ + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * components.scheduler.init_noise_sigma + return latents + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.dtype is None: + block_state.dtype = components.vae.dtype + + block_state.device = components._execution_device + + self.check_inputs(components, block_state) + + block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor + block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor + block_state.num_channels_latents = components.num_channels_latents + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ConfigSpec("requires_aesthetics_score", False),] + + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("aesthetic_score", default=6.0), + InputParam("negative_aesthetic_score", default=2.0), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt.
Can be generated in input step."), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components + @staticmethod + def _get_add_time_ids_img2img( + components, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if components.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod + def get_guidance_scale_embedding( + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.vae_scale_factor = components.vae_scale_factor + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * block_state.vae_scale_factor + block_state.width = block_state.width * block_state.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + if block_state.negative_original_size is None: + block_state.negative_original_size = block_state.original_size + if block_state.negative_target_size is None: + block_state.negative_target_size = block_state.target_size + + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids_img2img( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.aesthetic_score, + block_state.negative_aesthetic_score, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and 
components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that prepares the additional conditioning for the text-to-image generation process" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="The negative time ids to condition the denoising process"), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components + @staticmethod + def _get_add_time_ids( + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + )
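+ # For example (hypothetical values): original_size=(1024, 1024), crops_coords_top_left=(0, 0) and
+ # target_size=(1024, 1024) give add_time_ids == [1024, 1024, 0, 0, 1024, 1024], so with SDXL's
+ # addition_time_embed_dim=256 and a projection_dim of 1280, passed_add_embed_dim = 256 * 6 + 1280 = 2816.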
+ + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + @staticmethod + def get_guidance_scale_embedding( + w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + block_state.add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + else: + block_state.negative_add_time_ids = block_state.add_time_ids + + block_state.add_time_ids = block_state.add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(block_state.batch_size * block_state.num_images_per_prompt, 1).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
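+ # For example (hypothetical values): with guidance_scale=7.5 and batch_size * num_images_per_prompt=2,
+ # guidance_scale_tensor == tensor([6.5, 6.5]) and timestep_cond gets shape (2, time_cond_proj_dim).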
+ block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(block_state.batch_size * block_state.num_images_per_prompt) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.add_block_state(state, block_state) + return components, state + + +class StableDiffusionXLControlNetInputStep(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "Step that prepares the inputs for the ControlNet model" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." + ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), + OutputParam("control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"), + OutputParam("control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without applying any guidance + # 2. add crops_coords and resize_mode to preprocess()
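+ # For example (hypothetical values): a single control image with batch_size=2 prompts and
+ # num_images_per_prompt=2 is preprocessed to (1, 3, H, W) and repeat_interleave'd to a final
+ # controlnet_cond batch of 4, matching the latent batch.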
+ @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + # (1) prepare controlnet inputs + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + controlnet = unwrap_module(components.controlnet) + + # (1.1) + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance(block_state.control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # (1.2) + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, MultiControlNetModel) and isinstance(block_state.controlnet_conditioning_scale, float): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(controlnet.nets) + + # (1.3) + # global_pool_conditions + block_state.global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + # (1.4) + # guess_mode + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # (1.5) + # control_image + if isinstance(controlnet, ControlNetModel): + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in block_state.control_image: + control_image =
self.prepare_control_image( + components, + image=control_image_, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + + control_images.append(control_image) + + block_state.control_image = control_images + else: + assert False + + # (1.6) + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + + + self.add_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), + ] + + @property + def description(self) -> str: + return "step that prepares inputs for the ControlNetUnion model" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_mode", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + ] + + @property + def intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step." + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of model tensor inputs. Can be generated in input step." + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step." + ), + InputParam( + "crops_coords", + type_hint=Optional[Tuple[int]], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." 
+ ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"), + OutputParam("control_type_idx", type_hint=List[int], description="The control mode indices", kwargs_type="controlnet_kwargs"), + OutputParam("control_type", type_hint=torch.Tensor, description="The control type tensor that specifies which control type is active", kwargs_type="controlnet_kwargs"), + OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), + OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), + OutputParam("conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without applying any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + device = components._execution_device + dtype = block_state.dtype or components.controlnet.dtype + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance(block_state.control_guidance_end, list): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [block_state.control_guidance_start] + elif not isinstance(block_state.control_guidance_end, list) and isinstance(block_state.control_guidance_start, list): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [block_state.control_guidance_end] + + # guess_mode + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # control_image + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + # control_mode + if not isinstance(block_state.control_mode, list): + block_state.control_mode =
[block_state.control_mode]
+
+        if len(block_state.control_image) != len(block_state.control_mode):
+            raise ValueError("Expected len(control_image) == len(control_mode)")
+
+        # control_type
+        block_state.num_control_type = controlnet.config.num_control_type
+        block_state.control_type = [0 for _ in range(block_state.num_control_type)]
+        for control_idx in block_state.control_mode:
+            block_state.control_type[control_idx] = 1
+        block_state.control_type = torch.Tensor(block_state.control_type)
+
+        block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype)
+        repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0]
+        block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0)
+
+        # prepare control_image
+        for idx, _ in enumerate(block_state.control_image):
+            block_state.control_image[idx] = self.prepare_control_image(
+                components,
+                image=block_state.control_image[idx],
+                width=block_state.width,
+                height=block_state.height,
+                batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+                num_images_per_prompt=block_state.num_images_per_prompt,
+                device=device,
+                dtype=dtype,
+                crops_coords=block_state.crops_coords,
+            )
+            block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
+
+        # controlnet_keep
+        block_state.controlnet_keep = []
+        for i in range(len(block_state.timesteps)):
+            block_state.controlnet_keep.append(
+                1.0
+                - float(i / len(block_state.timesteps) < block_state.control_guidance_start or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end)
+            )
+        block_state.control_type_idx = block_state.control_mode
+        block_state.controlnet_cond = block_state.control_image
+        block_state.conditioning_scale = block_state.controlnet_conditioning_scale
+
+        self.add_block_state(state, block_state)
+
+        return components, state
+
+
+class StableDiffusionXLControlNetAutoInput(AutoPipelineBlocks):
+
+    block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep]
+    block_names = ["controlnet_union", "controlnet"]
+    block_trigger_inputs = ["control_mode", "control_image"]
+
+
+# Before denoise
+class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
+    block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput]
+    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"]
+
+    @property
+    def description(self):
+        return "Before denoise step that prepares the inputs for the denoise step.\n" + \
+               "This is a sequential pipeline of blocks:\n" + \
+               " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \
+               " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \
+               " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \
+               " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \
+               " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input"
+
+
+class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
+    block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput]
+    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"]
+
+    @property
+    def description(self):
+        return "Before denoise step that prepares the inputs for the denoise step for the img2img task.\n" + \
+               "This is a sequential pipeline of blocks:\n" + \
+               " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \
+               " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \
+               " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \
+               " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \
+               " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input"
+
+
+class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
+    block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, StableDiffusionXLControlNetAutoInput]
+    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond", "controlnet_input"]
+
+    @property
+    def description(self):
+        return "Before denoise step that prepares the inputs for the denoise step for the inpainting task.\n" + \
+               "This is a sequential pipeline of blocks:\n" + \
+               " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \
+               " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \
+               " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \
+               " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + \
+               " - `StableDiffusionXLControlNetAutoInput` is used to prepare the controlnet input"
+
+
+class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
+    block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep]
+    block_names = ["inpaint", "img2img", "text2img"]
+    block_trigger_inputs = ["mask", "image_latents", None]
+
+    @property
+    def description(self):
+        return "Before denoise step that prepares the inputs for the denoise step.\n" + \
+               "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + \
+               " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \
+               " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \
+               " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when neither `image_latents` nor `mask` is provided.\n" + \
+               " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + \
+               " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided."
+
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
new file mode 100644
index 000000000000..ca848e20984f
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
@@ -0,0 +1,215 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, List, Optional, Tuple, Union, Dict
+
+import PIL
+import torch
+import numpy as np
+from collections import OrderedDict
+
+from ...image_processor import VaeImageProcessor, PipelineImageInput
+from ...models import AutoencoderKL
+from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+from ...utils import logging
+
+from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from ...configuration_utils import FrozenDict
+
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from ..modular_pipeline import (
+    AutoPipelineBlocks,
+    PipelineBlock,
+    PipelineState,
+    SequentialPipelineBlocks,
+)
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class StableDiffusionXLDecodeStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKL),
+            ComponentSpec(
+                "image_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 8}),
+                default_creation_method="from_config"),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "Step that decodes the denoised latents into images"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("output_type", default="pil"),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")]
+
+    @property
+    def intermediates_outputs(self) -> List[str]:
+        return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")]
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components
+    @staticmethod
+    def upcast_vae(components):
+        dtype = components.vae.dtype
+        components.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            components.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            components.vae.post_quant_conv.to(dtype)
+            components.vae.decoder.conv_in.to(dtype)
+            components.vae.decoder.mid_block.to(dtype)
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        if block_state.output_type != "latent":
+            latents = block_state.latents
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
+
+            if block_state.needs_upcasting:
+                self.upcast_vae(components)
+                latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
+            elif latents.dtype != components.vae.dtype:
+                if torch.backends.mps.is_available():
+                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                    components.vae = components.vae.to(latents.dtype)
+
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            block_state.has_latents_mean = (
+                hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
+            )
+            block_state.has_latents_std = (
+                hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
+            )
+            if block_state.has_latents_mean and block_state.has_latents_std:
+                block_state.latents_mean = (
+                    torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                block_state.latents_std = (
+                    torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
+            else:
+                latents = latents / components.vae.config.scaling_factor
+
+            block_state.images = components.vae.decode(latents, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if block_state.needs_upcasting:
+                components.vae.to(dtype=torch.float16)
+        else:
+            block_state.images = block_state.latents
+
+        # apply watermark if available
+        if hasattr(components, "watermark") and components.watermark is not None:
+            block_state.images = components.watermark.apply_watermark(block_state.images)
+
+        block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type)
+
+        self.add_block_state(state, block_state)
+
+        return components, state
+
+
+class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def description(self) -> str:
+        return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \
+               "Only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("image", required=True),
+            InputParam("mask_image", required=True),
+            InputParam("padding_mask_crop"),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], description="The generated images from the decode step"),
+            InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.")
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[str]:
+        return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], description="The generated images with the mask overlaid")]
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        if block_state.padding_mask_crop is not None and block_state.crops_coords is not None:
+            block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images]
+
+        self.add_block_state(state, block_state)
+
+        return components, state
+
+
+class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
+    block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep]
+    block_names = ["decode", "mask_overlay"]
+
+    @property
+    def description(self):
+        return "Inpaint decode step that decodes the denoised latents into image outputs.\n" + \
+               "This is a sequential pipeline of blocks:\n" + \
+               " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + \
+               " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image"
+
+
+class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
+    block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
+    block_names = ["inpaint", "non-inpaint"]
+    block_trigger_inputs = ["padding_mask_crop", None]
+
+    @property
+    def description(self):
+        return "Decode step that decodes the denoised latents into image outputs.\n" + \
+               "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \
+               " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \
+               " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
+
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
new file mode 100644
index 000000000000..bc567a6b034f
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
@@ -0,0 +1,1334 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
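Note on the dispatch rule used by the Auto* blocks above: `block_trigger_inputs` is scanned in order, the first block whose trigger value is present in the state is selected, and a `None` trigger acts as the fallback. The base `AutoPipelineBlocks` resolution logic is not part of this diff, so the following is only a minimal, self-contained sketch of that rule as inferred from how the subclasses declare their triggers; `ToyAutoBlocks`, `InpaintDecode`, `Decode` and `AutoDecode` are illustrative names, not real diffusers classes.

from typing import Any, Dict, List, Optional


class ToyAutoBlocks:
    """Stand-in for AutoPipelineBlocks trigger dispatch (illustration only)."""

    block_classes: List[type] = []
    block_trigger_inputs: List[Optional[str]] = []

    def resolve(self, state: Dict[str, Any]):
        # Pick the first block whose trigger input is present; None = fallback.
        for block_cls, trigger in zip(self.block_classes, self.block_trigger_inputs):
            if trigger is None or state.get(trigger) is not None:
                return block_cls()
        raise ValueError("no block matched the provided state")


class InpaintDecode:
    def __call__(self, state):
        state["images"] = "decoded + mask overlay"
        return state


class Decode:
    def __call__(self, state):
        state["images"] = "decoded"
        return state


class AutoDecode(ToyAutoBlocks):
    # mirrors StableDiffusionXLAutoDecodeStep above
    block_classes = [InpaintDecode, Decode]
    block_trigger_inputs = ["padding_mask_crop", None]


auto = AutoDecode()
print(auto.resolve({"padding_mask_crop": 32})({"padding_mask_crop": 32})["images"])  # decoded + mask overlay
print(auto.resolve({})({})["images"])  # decoded

Ordering the more specific block (inpaint) before the fallback is what makes the `None` trigger safe: it can only be reached when no earlier trigger matched.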
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from tqdm.auto import tqdm
+
+from ...configuration_utils import FrozenDict
+from ...models import ControlNetModel, UNet2DConditionModel
+from ...schedulers import EulerDiscreteScheduler
+from ...utils import logging
+from ...utils.torch_utils import unwrap_module
+
+from ...guiders import ClassifierFreeGuidance
+from .modular_loader import StableDiffusionXLModularLoader
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from ..modular_pipeline import (
+    PipelineBlock,
+    PipelineState,
+    AutoPipelineBlocks,
+    LoopSequentialPipelineBlocks,
+    BlockState,
+)
+from dataclasses import asdict
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# YiYi experimenting with a composable denoise loop
+# loop step (1): prepare latent input for denoiser
+class StableDiffusionXLDenoiseLoopBeforeDenoiser(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that prepares the latent input for the denoiser"
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
+
+        return components, block_state
+
+
+# loop step (1): prepare latent input for denoiser (with inpainting)
+class StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that prepares the latent input for the denoiser"
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
+            ),
+            InputParam(
+                "mask",
+                type_hint=Optional[torch.Tensor],
+                description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
+            ),
+            InputParam(
+                "masked_image_latents",
+                type_hint=Optional[torch.Tensor],
+                description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
+ ), + ] + + + @staticmethod + def check_inputs(components, block_state): + + num_channels_unet = components.num_channels_unet + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + if components.num_channels_unet == 9: + block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + + return components, block_state + +# loop step (2): denoise the latents with guidance +class StableDiffusionXLDenoiseLoopDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + + ] + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int) -> PipelineState: + + # Map the keys we'll see on each `guider_state_batch` (e.g. 
guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k:v for k,v in cond_kwargs.items() if k in guider_input_fields} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=cond_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + +# loop step (2): denoise the latents with guidance (with controlnet) +class StableDiffusionXLControlNetDenoiseLoopDenoiser(PipelineBlock): + + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + ] + + @property + def intermediates_inputs(self) -> List[str]: + return [ + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
+ ), + InputParam( + "controlnet_keep", + required=True, + type_hint=List[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." + ), + InputParam( + "timestep_cond", + type_hint=Optional[torch.Tensor], + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ) + ) + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + extra_controlnet_kwargs = self.prepare_extra_kwargs(components.controlnet.forward, **block_state.controlnet_kwargs) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields ={ + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + "time_ids": ("add_time_ids", "negative_add_time_ids"), + "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), + "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"), + } + + + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] + else: + controlnet_cond_scale = block_state.conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. 
for CFG, we prepare two batches: one for cond, one for uncond
+        # for the first batch, guider_state_batch.prompt_embeds corresponds to block_state.prompt_embeds
+        # for the second batch, guider_state_batch.prompt_embeds corresponds to block_state.negative_prompt_embeds
+        guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+        # run the denoiser for each guidance batch
+        for guider_state_batch in guider_state:
+            components.guider.prepare_models(components.unet)
+
+            # Prepare additional conditionings
+            added_cond_kwargs = {
+                "text_embeds": guider_state_batch.text_embeds,
+                "time_ids": guider_state_batch.time_ids,
+            }
+            if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None:
+                added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds
+
+            # Prepare controlnet additional conditionings
+            controlnet_added_cond_kwargs = {
+                "text_embeds": guider_state_batch.text_embeds,
+                "time_ids": guider_state_batch.time_ids,
+            }
+            # run controlnet for the guidance batch
+            if block_state.guess_mode and not components.guider.is_conditional:
+                # the guider runs the conditional batch first, so these zero tensors are already set
+                down_block_res_samples = block_state.down_block_res_samples_zeros
+                mid_block_res_sample = block_state.mid_block_res_sample_zeros
+            else:
+                down_block_res_samples, mid_block_res_sample = components.controlnet(
+                    block_state.scaled_latents,
+                    t,
+                    encoder_hidden_states=guider_state_batch.prompt_embeds,
+                    controlnet_cond=block_state.controlnet_cond,
+                    conditioning_scale=block_state.cond_scale,
+                    guess_mode=block_state.guess_mode,
+                    added_cond_kwargs=controlnet_added_cond_kwargs,
+                    return_dict=False,
+                    **extra_controlnet_kwargs,
+                )
+
+                # assign it to block_state so it will be available for the uncond guidance batch
+                if block_state.down_block_res_samples_zeros is None:
+                    block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples]
+                if block_state.mid_block_res_sample_zeros is None:
+                    block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample)
+
+            # Predict the noise
+            # store the noise_pred in guider_state_batch so we can apply guidance across all batches
+            guider_state_batch.noise_pred = components.unet(
+                block_state.scaled_latents,
+                t,
+                encoder_hidden_states=guider_state_batch.prompt_embeds,
+                timestep_cond=block_state.timestep_cond,
+                cross_attention_kwargs=block_state.cross_attention_kwargs,
+                added_cond_kwargs=added_cond_kwargs,
+                down_block_additional_residuals=down_block_res_samples,
+                mid_block_additional_residual=mid_block_res_sample,
+                return_dict=False,
+            )[0]
+            components.guider.cleanup_models(components.unet)
+
+        # Perform guidance
+        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+
+        return components, block_state
+
+
+# loop step (3): scheduler step to update latents
+class StableDiffusionXLDenoiseLoopAfterDenoiser(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that updates the latents with the scheduler step, using the guided noise prediction"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("eta", default=0.0),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam("generator"),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+    # YiYi TODO: move this out of here
+    @staticmethod
+    def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+
+        accepted_kwargs = set(inspect.signature(func).parameters.keys())
+        extra_kwargs = {}
+        for key, value in kwargs.items():
+            if key in accepted_kwargs and key not in exclude_kwargs:
+                extra_kwargs[key] = value
+
+        return extra_kwargs
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int):
+
+        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta)
+
+        # Perform scheduler step using the predicted output
+        block_state.latents_dtype = block_state.latents.dtype
+        block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0]
+
+        if block_state.latents.dtype != block_state.latents_dtype:
+            if torch.backends.mps.is_available():
+                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                block_state.latents = block_state.latents.to(block_state.latents_dtype)
+
+        return components, block_state
+
+
+# loop step (3): scheduler step to update latents (with inpainting)
+class StableDiffusionXLInpaintDenoiseLoopAfterDenoiser(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", EulerDiscreteScheduler),
+            ComponentSpec("unet", UNet2DConditionModel),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "step within the denoising loop that updates the latents with the scheduler step and, for inpainting, blends the denoised latents with the original image latents outside the mask"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("eta", default=0.0),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return [
+            InputParam("generator"),
+            InputParam(
+                "timesteps",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
+            ),
+            InputParam(
+                "mask",
+                type_hint=Optional[torch.Tensor],
+                description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
+            ),
+            InputParam(
+                "noise",
+                type_hint=Optional[torch.Tensor],
+                description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step."
+            ),
+            InputParam(
+                "image_latents",
+                type_hint=Optional[torch.Tensor],
+                description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step."
+ ), + ] + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + def check_inputs(self, components, block_state): + if components.num_channels_unet == 4: + if block_state.image_latents is None: + raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") + if block_state.mask is None: + raise ValueError(f"mask is required for this step {self.__class__.__name__}") + if block_state.noise is None: + raise ValueError(f"noise is required for this step {self.__class__.__name__}") + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + + self.check_inputs(components, block_state) + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) + + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **block_state.scheduler_step_kwargs, return_dict=False)[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + # adjust latent for inpainting + if components.num_channels_unet == 4: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) + ) + + block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + + + + return components, block_state + + +# the loop wrapper that iterates over the timesteps +class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def loop_intermediates_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." 
+            ),
+            InputParam(
+                "num_inference_steps",
+                required=True,
+                type_hint=int,
+                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
+        if block_state.disable_guidance:
+            components.guider.disable()
+        else:
+            components.guider.enable()
+
+        block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0)
+
+        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+            for i, t in enumerate(block_state.timesteps):
+                components, block_state = self.loop_step(components, block_state, i=i, t=t)
+                if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0):
+                    progress_bar.update()
+
+        self.add_block_state(state, block_state)
+
+        return components, state
+
+
+# composing the denoising loops
+class StableDiffusionXLDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
+    block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
+    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+
+# control_cond
+class StableDiffusionXLControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
+    block_classes = [StableDiffusionXLDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
+    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+
+# mask
+class StableDiffusionXLInpaintDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
+    block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser]
+    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+
+# control_cond + mask
+class StableDiffusionXLInpaintControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
+    block_classes = [StableDiffusionXLInpaintDenoiseLoopBeforeDenoiser, StableDiffusionXLControlNetDenoiseLoopDenoiser, StableDiffusionXLInpaintDenoiseLoopAfterDenoiser]
+    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+
+# all tasks without controlnet
+class StableDiffusionXLDenoiseStep(AutoPipelineBlocks):
+    block_classes = [StableDiffusionXLInpaintDenoiseLoop, StableDiffusionXLDenoiseLoop]
+    block_names = ["inpaint_denoise", "denoise"]
+    block_trigger_inputs = ["mask", None]
+
+
+# all tasks with controlnet
+class StableDiffusionXLControlNetDenoiseStep(AutoPipelineBlocks):
+    block_classes = [StableDiffusionXLInpaintControlNetDenoiseLoop, StableDiffusionXLControlNetDenoiseLoop]
+    block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"]
+    block_trigger_inputs = ["mask", None]
+
+
+# all tasks, with or without controlnet
+class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
+    block_classes = [StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep]
+    block_names = ["controlnet_denoise", "denoise"]
+    block_trigger_inputs = ["controlnet_cond", None]
+
+
+# YiYi Notes: alternatively, this is how you can write the denoise loop as a single pipeline block; easier, but not composable
+# class StableDiffusionXLDenoiseStep(PipelineBlock):
+
+#     model_name =
"stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ] + +# @property +# def description(self) -> str: +# return ( +# "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process" +# ) + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("num_images_per_prompt", default=1), +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." 
+# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) + +# # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with self -> components +# @staticmethod +# def prepare_extra_step_kwargs(components, generator, eta): +# # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature +# # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+# # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 +# # and should be between [0, 1] + +# accepts_eta = "eta" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# extra_step_kwargs = {} +# if accepts_eta: +# extra_step_kwargs["eta"] = eta + +# # check if the scheduler accepts generator +# accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) +# if accepts_generator: +# extra_step_kwargs["generator"] = generator +# return extra_step_kwargs + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) + +# block_state.num_channels_unet = components.unet.config.in_channels +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() + +# # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_step_kwargs(components, block_state.generator, block_state.eta) +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_data = components.guider.prepare_inputs(block_state) + +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + +# # Prepare for inpainting +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + +# for batch in guider_data: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# batch.added_cond_kwargs = { +# "text_embeds": batch.pooled_prompt_embeds, +# "time_ids": batch.add_time_ids, +# } +# if batch.ip_adapter_embeds is not None: +# batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds + +# # Predict the noise residual +# batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=batch.added_cond_kwargs, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_data) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if 
torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state + + + +# class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): + +# model_name = "stable-diffusion-xl" + +# @property +# def expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec( +# "guider", +# ClassifierFreeGuidance, +# config=FrozenDict({"guidance_scale": 7.5}), +# default_creation_method="from_config"), +# ComponentSpec("scheduler", EulerDiscreteScheduler), +# ComponentSpec("unet", UNet2DConditionModel), +# ComponentSpec("controlnet", ControlNetModel), +# ] + +# @property +# def description(self) -> str: +# return "step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process. Using ControlNet to condition the denoising process" + +# @property +# def inputs(self) -> List[Tuple[str, Any]]: +# return [ +# InputParam("num_images_per_prompt", default=1), +# InputParam("cross_attention_kwargs"), +# InputParam("generator"), +# InputParam("eta", default=0.0), +# InputParam("controlnet_conditioning_scale", type_hint=float, default=1.0), # can expect either input or intermediate input, (intermediate input if both are passed) +# ] + +# @property +# def intermediates_inputs(self) -> List[str]: +# return [ +# InputParam( +# "controlnet_cond", +# required=True, +# type_hint=torch.Tensor, +# description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_start", +# required=True, +# type_hint=float, +# description="The control guidance start value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "control_guidance_end", +# required=True, +# type_hint=float, +# description="The control guidance end value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "conditioning_scale", +# type_hint=float, +# description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "guess_mode", +# required=True, +# type_hint=bool, +# description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step." +# ), +# InputParam( +# "controlnet_keep", +# required=True, +# type_hint=List[float], +# description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step." 
+# ), +# InputParam( +# "latents", +# required=True, +# type_hint=torch.Tensor, +# description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." +# ), +# InputParam( +# "batch_size", +# required=True, +# type_hint=int, +# description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." +# ), +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam( +# "prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "add_time_ids", +# required=True, +# type_hint=torch.Tensor, +# description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "negative_add_time_ids", +# type_hint=Optional[torch.Tensor], +# description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." +# ), +# InputParam( +# "pooled_prompt_embeds", +# required=True, +# type_hint=torch.Tensor, +# description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "negative_pooled_prompt_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." +# ), +# InputParam( +# "timestep_cond", +# type_hint=Optional[torch.Tensor], +# description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" +# ), +# InputParam( +# "mask", +# type_hint=Optional[torch.Tensor], +# description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "masked_image_latents", +# type_hint=Optional[torch.Tensor], +# description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "noise", +# type_hint=Optional[torch.Tensor], +# description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." +# ), +# InputParam( +# "image_latents", +# type_hint=Optional[torch.Tensor], +# description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." +# ), +# InputParam( +# "crops_coords", +# type_hint=Optional[Tuple[int]], +# description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." +# ), +# InputParam( +# "ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
+# ), +# InputParam( +# "negative_ip_adapter_embeds", +# type_hint=Optional[torch.Tensor], +# description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." +# ), +# InputParam(kwargs_type="controlnet_kwargs", description="additional kwargs for controlnet") +# ] + +# @property +# def intermediates_outputs(self) -> List[OutputParam]: +# return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + +# @staticmethod +# def check_inputs(components, block_state): + +# num_channels_unet = components.unet.config.in_channels +# if num_channels_unet == 9: +# # default case for runwayml/stable-diffusion-inpainting +# if block_state.mask is None or block_state.masked_image_latents is None: +# raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") +# num_channels_latents = block_state.latents.shape[1] +# num_channels_mask = block_state.mask.shape[1] +# num_channels_masked_image = block_state.masked_image_latents.shape[1] +# if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: +# raise ValueError( +# f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" +# f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" +# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" +# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" +# " `components.unet` or your `mask_image` or `image` input." +# ) +# @staticmethod +# def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + +# accepted_kwargs = set(inspect.signature(func).parameters.keys()) +# extra_kwargs = {} +# for key, value in kwargs.items(): +# if key in accepted_kwargs and key not in exclude_kwargs: +# extra_kwargs[key] = value + +# return extra_kwargs + + +# @torch.no_grad() +# def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + +# block_state = self.get_block_state(state) +# self.check_inputs(components, block_state) +# block_state.device = components._execution_device +# print(f" block_state: {block_state}") + +# controlnet = unwrap_module(components.controlnet) + +# # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline +# block_state.extra_step_kwargs = self.prepare_extra_kwargs(components.scheduler.step, generator=block_state.generator, eta=block_state.eta) +# block_state.extra_controlnet_kwargs = self.prepare_extra_kwargs(controlnet.forward, exclude_kwargs=["controlnet_cond", "conditioning_scale", "guess_mode"], **block_state.controlnet_kwargs) + +# block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + +# # (1) setup guider +# # disable for LCMs +# block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False +# if block_state.disable_guidance: +# components.guider.disable() +# else: +# components.guider.enable() +# components.guider.set_input_fields( +# prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), +# add_time_ids=("add_time_ids", "negative_add_time_ids"), +# pooled_prompt_embeds=("pooled_prompt_embeds", "negative_pooled_prompt_embeds"), +# ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), +# ) + +# # (5) Denoise loop +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): + +# # prepare latent input for unet +# block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) +# # adjust latent input for inpainting +# block_state.num_channels_unet = components.unet.config.in_channels +# if block_state.num_channels_unet == 9: +# block_state.scaled_latents = torch.cat([block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1) + + +# # cond_scale (controlnet input) +# if isinstance(block_state.controlnet_keep[i], list): +# block_state.cond_scale = [c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])] +# else: +# block_state.controlnet_cond_scale = block_state.conditioning_scale +# if isinstance(block_state.controlnet_cond_scale, list): +# block_state.controlnet_cond_scale = block_state.controlnet_cond_scale[0] +# block_state.cond_scale = block_state.controlnet_cond_scale * block_state.controlnet_keep[i] + +# # default controlnet output/unet input for guess mode + conditional path +# block_state.down_block_res_samples_zeros = None +# block_state.mid_block_res_sample_zeros = None + +# # guided denoiser step +# components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) +# guider_state = components.guider.prepare_inputs(block_state) + +# for guider_state_batch in guider_state: +# components.guider.prepare_models(components.unet) + +# # Prepare additional conditionings +# guider_state_batch.added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } +# if guider_state_batch.ip_adapter_embeds is not None: +# guider_state_batch.added_cond_kwargs["image_embeds"] = guider_state_batch.ip_adapter_embeds + +# # Prepare controlnet additional conditionings +# guider_state_batch.controlnet_added_cond_kwargs = { +# "text_embeds": guider_state_batch.pooled_prompt_embeds, +# "time_ids": guider_state_batch.add_time_ids, +# } + +# if block_state.guess_mode and not components.guider.is_conditional: +# # guider always run uncond batch first, so these tensors should be set already +# guider_state_batch.down_block_res_samples = block_state.down_block_res_samples_zeros +# guider_state_batch.mid_block_res_sample = 
block_state.mid_block_res_sample_zeros +# else: +# guider_state_batch.down_block_res_samples, guider_state_batch.mid_block_res_sample = components.controlnet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# controlnet_cond=block_state.controlnet_cond, +# conditioning_scale=block_state.conditioning_scale, +# guess_mode=block_state.guess_mode, +# added_cond_kwargs=guider_state_batch.controlnet_added_cond_kwargs, +# return_dict=False, +# **block_state.extra_controlnet_kwargs, +# ) + +# if block_state.down_block_res_samples_zeros is None: +# block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in guider_state_batch.down_block_res_samples] +# if block_state.mid_block_res_sample_zeros is None: +# block_state.mid_block_res_sample_zeros = torch.zeros_like(guider_state_batch.mid_block_res_sample) + + + +# guider_state_batch.noise_pred = components.unet( +# block_state.scaled_latents, +# t, +# encoder_hidden_states=guider_state_batch.prompt_embeds, +# timestep_cond=block_state.timestep_cond, +# cross_attention_kwargs=block_state.cross_attention_kwargs, +# added_cond_kwargs=guider_state_batch.added_cond_kwargs, +# down_block_additional_residuals=guider_state_batch.down_block_res_samples, +# mid_block_additional_residual=guider_state_batch.mid_block_res_sample, +# return_dict=False, +# )[0] +# components.guider.cleanup_models(components.unet) + +# # Perform guidance +# block_state.noise_pred, scheduler_step_kwargs = components.guider(guider_state) + +# # Perform scheduler step using the predicted output +# block_state.latents_dtype = block_state.latents.dtype +# block_state.latents = components.scheduler.step(block_state.noise_pred, t, block_state.latents, **block_state.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] + +# if block_state.latents.dtype != block_state.latents_dtype: +# if torch.backends.mps.is_available(): +# # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 +# block_state.latents = block_state.latents.to(block_state.latents_dtype) + +# # adjust latent for inpainting +# if block_state.num_channels_unet == 4 and block_state.mask is not None and block_state.image_latents is not None: +# block_state.init_latents_proper = block_state.image_latents +# if i < len(block_state.timesteps) - 1: +# block_state.noise_timestep = block_state.timesteps[i + 1] +# block_state.init_latents_proper = components.scheduler.add_noise( +# block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) +# ) + +# block_state.latents = (1 - block_state.mask) * block_state.init_latents_proper + block_state.mask * block_state.latents + +# if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): +# progress_bar.update() + +# self.add_block_state(state, block_state) + +# return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py new file mode 100644 index 000000000000..ca4efe2c4a7f --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -0,0 +1,858 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, List, Optional, Tuple, Union, Dict + +import PIL +import torch +from collections import OrderedDict + +from ...image_processor import VaeImageProcessor, PipelineImageInput +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin +from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from ...models.lora import adjust_lora_scale_text_encoder +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor, unwrap_module +from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel +from ...configuration_utils import FrozenDict + +from transformers import ( + CLIPTextModel, + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...schedulers import EulerDiscreteScheduler +from ...guiders import ClassifierFreeGuidance + +from .modular_loader import StableDiffusionXLModularLoader +from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec + +import numpy as np + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLIPAdapterStep(PipelineBlock): + model_name = "stable-diffusion-xl" + + + @property + def description(self) -> str: + return ( + "IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" + " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + 
InputParam( + "ip_adapter_image", + PipelineImageInput, + required=True, + description="The image(s) to be used as ip adapter" + ) + ] + + + @property + def intermediates_outputs(self) -> List[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") + ] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components + @staticmethod + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds + ): + image_embeds = [] + if prepare_unconditional_embeds: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    components, single_ip_adapter_image, device, 1, output_hidden_state
+                )
+
+                image_embeds.append(single_image_embeds[None, :])
+                if prepare_unconditional_embeds:
+                    negative_image_embeds.append(single_negative_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                if prepare_unconditional_embeds:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    negative_image_embeds.append(single_negative_image_embeds)
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if prepare_unconditional_embeds:
+                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+        block_state.device = components._execution_device
+
+        block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
+            components,
+            ip_adapter_image=block_state.ip_adapter_image,
+            ip_adapter_image_embeds=None,
+            device=block_state.device,
+            num_images_per_prompt=1,
+            prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
+        )
+        if block_state.prepare_unconditional_embeds:
+            block_state.negative_ip_adapter_embeds = []
+            for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
+                negative_image_embeds, image_embeds = image_embeds.chunk(2)
+                block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
+                block_state.ip_adapter_embeds[i] = image_embeds
+
+        self.add_block_state(state, block_state)
+        return components, state
+
+
+class StableDiffusionXLTextEncoderStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Text Encoder step that generates text embeddings to guide the image generation"
+        )
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("text_encoder", CLIPTextModel),
+            ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
+            ComponentSpec("tokenizer", CLIPTokenizer),
+            ComponentSpec("tokenizer_2", CLIPTokenizer),
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 7.5}),
+                default_creation_method="from_config"),
+        ]
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return [ConfigSpec("force_zeros_for_empty_prompt", True)]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("prompt"),
+            InputParam("prompt_2"),
+            InputParam("negative_prompt"),
+            InputParam("negative_prompt_2"),
+            InputParam("cross_attention_kwargs"),
+            InputParam("clip_skip"),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [
OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"), + ] + + @staticmethod + def check_inputs(block_state): + + if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prepare_unconditional_embeds: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
+ If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) + else: + scale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) + else: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2] + text_encoders = ( + [components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. 
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt + if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif prepare_unconditional_embeds and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if components.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prepare_unconditional_embeds: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if components.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + 
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+        if prepare_unconditional_embeds:
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                bs_embed * num_images_per_prompt, -1
+            )
+
+        if components.text_encoder is not None:
+            if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(components.text_encoder, lora_scale)
+
+        if components.text_encoder_2 is not None:
+            if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(components.text_encoder_2, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
+        # Get inputs and intermediates
+        block_state = self.get_block_state(state)
+        self.check_inputs(block_state)
+
+        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+        block_state.device = components._execution_device
+
+        # Encode input prompt
+        block_state.text_encoder_lora_scale = (
+            block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None
+        )
+        (
+            block_state.prompt_embeds,
+            block_state.negative_prompt_embeds,
+            block_state.pooled_prompt_embeds,
+            block_state.negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            components,
+            block_state.prompt,
+            block_state.prompt_2,
+            block_state.device,
+            1,
+            block_state.prepare_unconditional_embeds,
+            block_state.negative_prompt,
+            block_state.negative_prompt_2,
+            prompt_embeds=None,
+            negative_prompt_embeds=None,
+            pooled_prompt_embeds=None,
+            negative_pooled_prompt_embeds=None,
+            lora_scale=block_state.text_encoder_lora_scale,
+            clip_skip=block_state.clip_skip,
+        )
+        # Add outputs
+        self.add_block_state(state, block_state)
+        return components, state
+
+
+class StableDiffusionXLVaeEncoderStep(PipelineBlock):
+
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Vae Encoder step that encodes the input image into a latent representation"
+        )
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKL),
+            ComponentSpec(
+                "image_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 8}),
+                default_creation_method="from_config"),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("image", required=True),
+            InputParam("height"),
+            InputParam("width"),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[InputParam]:
+        return [
+            InputParam("generator"),
+            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+            InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that, if specified, is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")]
+
+    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
+    # YiYi TODO: update the _encode_vae_image so that we can use # Copied from
+    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
+        latents_mean = latents_std = None
+        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
+            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
+        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
+            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
+
+        dtype = image.dtype
+        if components.vae.config.force_upcast:
+            image = image.float()
+            components.vae.to(dtype=torch.float32)
+
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
+
+        if components.vae.config.force_upcast:
+            components.vae.to(dtype)
+
+        image_latents = image_latents.to(dtype)
+        if latents_mean is not None and latents_std is not None:
+            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
+            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
+            image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
+        else:
+            image_latents = components.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
+        block_state.device = components._execution_device
+        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+
+        block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs)
+        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
+
+        block_state.batch_size = block_state.image.shape[0]
+
+        # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
+        if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
+                f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator)
+
+        self.add_block_state(state, block_state)
+
+        return components, state
+
+
+class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKL),
+            ComponentSpec(
+                "image_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 8}),
+                default_creation_method="from_config"),
+            ComponentSpec(
+                "mask_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}),
+                default_creation_method="from_config"),
+        ]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Vae encoder step that prepares the image and mask for the inpainting process"
+        )
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("height"),
+            InputParam("width"),
+            InputParam("image", required=True),
+            InputParam("mask_image", required=True),
+            InputParam("padding_mask_crop"),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[InputParam]:
+        return [
+            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+            InputParam("generator"),
+        ]
+
+    @property
+    def intermediates_outputs(self) -> List[OutputParam]:
+        return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latent representation of the input image"),
+                OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
+                OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specific unet)"),
+                OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")]
+
+    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
+    # YiYi TODO: update the _encode_vae_image so that we can use # Copied from
+    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
+        latents_mean = latents_std = None
+        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
+            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
+        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
+            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
+
+        dtype = image.dtype
+        if components.vae.config.force_upcast:
+            image = image.float()
+            components.vae.to(dtype=torch.float32)
+
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
+
+        if components.vae.config.force_upcast:
+            components.vae.to(dtype)
+
+        image_latents = image_latents.to(dtype)
+        if latents_mean is not None and latents_std is not None:
+            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
+            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
+            image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
+        else:
+            image_latents = components.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
+    # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
+    # does not accept do_classifier_free_guidance
+    def prepare_mask_latents(
+        self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+        mask = torch.nn.functional.interpolate(
+            mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
+        )
+        mask = mask.to(device=device, dtype=dtype)
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+        if masked_image is not None and masked_image.shape[1] == 4:
+            masked_image_latents = masked_image
+        else:
+            masked_image_latents = None
+
+        if masked_image is not None:
+            if masked_image_latents is None:
+                masked_image = masked_image.to(device=device, dtype=dtype)
+                masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
+
+            if masked_image_latents.shape[0] < batch_size:
+                if not batch_size % masked_image_latents.shape[0] == 0:
+                    raise ValueError(
+                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                        " Make sure the number of images that you pass is divisible by the total requested batch size."
+                    )
+                masked_image_latents = masked_image_latents.repeat(
+                    batch_size // masked_image_latents.shape[0], 1, 1, 1
+                )
+
+            # aligning device to prevent device errors when concatenating it with the latent model input
+            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+        return mask, masked_image_latents
+
+    @torch.no_grad()
+    def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+        block_state.device = components._execution_device
+
+        if block_state.padding_mask_crop is not None:
+            block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop)
+            block_state.resize_mode = "fill"
+        else:
+            block_state.crops_coords = None
+            block_state.resize_mode = "default"
+
+        block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode)
+        block_state.image = block_state.image.to(dtype=torch.float32)
+
+        block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords)
+        block_state.masked_image = block_state.image * (block_state.mask < 0.5)
+
+        block_state.batch_size = block_state.image.shape[0]
+        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
+        block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator)
+
+        # 7. Prepare mask latent variables
+        block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
+            components,
+            block_state.mask,
+            block_state.masked_image,
+            block_state.batch_size,
+            block_state.height,
+            block_state.width,
+            block_state.dtype,
+            block_state.device,
+            block_state.generator,
+        )
+
+        self.add_block_state(state, block_state)
+
+        return components, state
+
+
+# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file)
+# Encode
+class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
+    block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
+    block_names = ["inpaint", "img2img"]
+    block_trigger_inputs = ["mask_image", "image"]
+
+    @property
+    def description(self):
+        return "Vae encoder step that encodes the image inputs into their latent representations.\n" + \
+               "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \
+               " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \
+               " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided."
+
+
+class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin):
+    block_classes = [StableDiffusionXLIPAdapterStep]
+    block_names = ["ip_adapter"]
+    block_trigger_inputs = ["ip_adapter_image"]
+
+    @property
+    def description(self):
+        return "Run IP Adapter step if `ip_adapter_image` is provided."
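The two `Auto*` blocks above pick a sub-block based on which trigger inputs the caller actually passed. Below is a minimal sketch of the selection rule that `block_names` and `block_trigger_inputs` imply; the real dispatch lives in `AutoPipelineBlocks` in `modular_pipeline.py` and may differ in detail.

from typing import Any, Dict, List, Optional


def select_block(block_names: List[str], block_trigger_inputs: List[str], provided: Dict[str, Any]) -> Optional[str]:
    """Return the first block whose trigger input was provided (list order decides priority)."""
    for name, trigger in zip(block_names, block_trigger_inputs):
        if provided.get(trigger) is not None:
            return name
    return None


# Mirrors StableDiffusionXLAutoVaeEncoderStep: "inpaint" is listed first, so it
# wins whenever `mask_image` is present, even if `image` is also provided.
assert select_block(["inpaint", "img2img"], ["mask_image", "image"], {"image": "img", "mask_image": "mask"}) == "inpaint"
assert select_block(["inpaint", "img2img"], ["mask_image", "image"], {"image": "img"}) == "img2img"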
+
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py
new file mode 100644
index 000000000000..4af942af64e6
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py
@@ -0,0 +1,174 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Optional, Tuple, Union, Dict
+import PIL
+import torch
+import numpy as np
+
+from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
+from ...image_processor import PipelineImageInput
+from ...pipelines.pipeline_utils import StableDiffusionMixin
+from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from ...utils import logging
+
+from ..modular_pipeline import ModularLoader
+from ..modular_pipeline_utils import InputParam, OutputParam
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder?
+# YiYi Notes: model-specific components:
+## (1) it should inherit from ModularLoader
+## (2) acts like a container that holds components and configs
+## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
+## (4) inherit from model-specific loader class (e.g. StableDiffusionXLLoraLoaderMixin)
+## (5) how to use together with Components_manager?
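The notes above describe the loader as a config-aware container: each derived default (sample size, scale factor, channel counts) should come from a loaded component when one is attached, and otherwise fall back to a hard-coded SDXL value, which is exactly what the properties in the class below implement. A self-contained toy sketch of that fallback pattern (the `_Toy*` classes are invented stand-ins, not diffusers APIs):

class _ToyVaeConfig:
    block_out_channels = (128, 256, 512, 512)  # 4 blocks, as in the SDXL VAE


class _ToyVae:
    config = _ToyVaeConfig()


class ToyLoader:
    def __init__(self):
        self.vae = None  # attached later, e.g. via a ComponentsManager

    @property
    def vae_scale_factor(self) -> int:
        if self.vae is not None:
            # derive from the attached component: one 2x downsample per block after the first
            return 2 ** (len(self.vae.config.block_out_channels) - 1)
        return 8  # safe SDXL fallback before any component is loaded


loader = ToyLoader()
assert loader.vae_scale_factor == 8  # fallback, nothing loaded yet
loader.vae = _ToyVae()
assert loader.vae_scale_factor == 8  # now derived: 2 ** (4 - 1)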
+class StableDiffusionXLModularLoader( + ModularLoader, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + + +# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"), + "prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"), + "negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"), + "negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"), + "cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"), + "clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"), + "image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"), + "mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"), + "generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"), + "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"), + "timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"), + "sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"), + "denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam("strength", type_hint=float, default=0.3, description="How 
much to transform the reference image"), + "denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"), + "latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"), + "padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"), + "original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"), + "target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"), + "negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"), + "negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"), + "crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"), + "negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"), + "aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"), + "negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"), + "ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"), + "control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"), + "control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"), + "control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"), + "controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"), + "guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"), + "control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet") +} + + +SDXL_INTERMEDIATE_INPUTS_SCHEMA = { + "prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": 
InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"), + "latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"), + "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), + "num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"), + "latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"), + "image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"), + "negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images") +} + + +SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { + "prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"), + "negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"), + "pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"), + "negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"), + "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), + "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"), + "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), + "masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"), + "crops_coords": OutputParam("crops_coords", 
type_hint=Optional[Tuple[int]], description="Crop coordinates"), + "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), + "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), + "latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"), + "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), + "negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"), + "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), + "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"), + "negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"), + "images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images") +} + + +SDXL_OUTPUTS_SCHEMA = { + "images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images") +} + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py new file mode 100644 index 000000000000..6d909ab5a4a0 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_block_mappings.py @@ -0,0 +1,128 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
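# Illustrative sketch (not part of the diff): how the shared schema dicts above are
# meant to be consumed once they are wired in (the YiYi note marks them as not used
# yet). A block can declare its inputs by picking entries out of SDXL_INPUTS_SCHEMA
# instead of re-declaring each InputParam, so every block agrees on the default,
# type hint, and description for a given input name. `MyBlock` is hypothetical, and
# the sketch assumes InputParam and SDXL_INPUTS_SCHEMA are importable from the
# modular pipeline modules introduced in this diff.
from typing import List


class MyBlock:
    @property
    def inputs(self) -> List["InputParam"]:
        # Reuse the canonical definitions rather than redefining them per block.
        return [
            SDXL_INPUTS_SCHEMA["prompt"],
            SDXL_INPUTS_SCHEMA["num_inference_steps"],
            SDXL_INPUTS_SCHEMA["strength"],
        ]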
+ +from collections import OrderedDict + +# Import all the necessary block classes +from .denoise import ( + StableDiffusionXLAutoDenoiseStep, + StableDiffusionXLDenoiseStep, + StableDiffusionXLControlNetDenoiseStep +) +from .before_denoise import ( + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLInputStep, + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep +) +from .encoders import ( + StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLVaeEncoderStep, + StableDiffusionXLInpaintVaeEncoderStep, + StableDiffusionXLIPAdapterStep +) +from .decoders import ( + StableDiffusionXLDecodeStep, + StableDiffusionXLInpaintDecodeStep, + StableDiffusionXLAutoDecodeStep +) + + +# YiYi notes: comment out for now, work on this later +# block mapping +TEXT2IMAGE_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLSetTimestepsStep), + ("prepare_latents", StableDiffusionXLPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) + +IMAGE2IMAGE_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLDecodeStep) +]) + +INPAINT_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), + ("input", StableDiffusionXLInputStep), + ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), + ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), + ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), + ("denoise", StableDiffusionXLDenoiseStep), + ("decode", StableDiffusionXLInpaintDecodeStep) +]) + +CONTROLNET_BLOCKS = OrderedDict([ + ("controlnet_input", StableDiffusionXLControlNetInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +CONTROLNET_UNION_BLOCKS = OrderedDict([ + ("controlnet_input", StableDiffusionXLControlNetUnionInputStep), + ("denoise", StableDiffusionXLControlNetDenoiseStep), +]) + +IP_ADAPTER_BLOCKS = OrderedDict([ + ("ip_adapter", StableDiffusionXLIPAdapterStep), +]) + +AUTO_BLOCKS = OrderedDict([ + ("text_encoder", StableDiffusionXLTextEncoderStep), + ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), + ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), + ("decode", StableDiffusionXLAutoDecodeStep) +]) + 
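# Illustrative sketch (not part of the diff): each preset above maps block names to
# block *classes*. To get a runnable sequence, instantiate the blocks and hand them
# to SequentialPipelineBlocks.from_blocks_dict, which builds a combined block whose
# generated `doc` is assembled from each sub-block's InputParam/OutputParam lists.
# Assumes from_blocks_dict keeps this signature after the move from pipelines/ to
# modular_pipelines/.
from collections import OrderedDict

t2i_blocks = OrderedDict((name, cls()) for name, cls in TEXT2IMAGE_BLOCKS.items())
t2i = SequentialPipelineBlocks.from_blocks_dict(t2i_blocks)
print(t2i.doc)  # auto-generated documentation for the composed text-to-image sequence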
+AUTO_CORE_BLOCKS = OrderedDict([ + ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), + ("denoise", StableDiffusionXLAutoDenoiseStep), +]) + + +SDXL_SUPPORTED_BLOCKS = { + "text2img": TEXT2IMAGE_BLOCKS, + "img2img": IMAGE2IMAGE_BLOCKS, + "inpaint": INPAINT_BLOCKS, + "controlnet": CONTROLNET_BLOCKS, + "controlnet_union": CONTROLNET_UNION_BLOCKS, + "ip_adapter": IP_ADAPTER_BLOCKS, + "auto": AUTO_BLOCKS +} + + + diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py new file mode 100644 index 000000000000..637c7ac306d7 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline_presets.py @@ -0,0 +1,43 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Tuple, Union, Dict +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks + +from .denoise import StableDiffusionXLAutoDenoiseStep +from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep +from .decoders import StableDiffusionXLAutoDecodeStep +from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep + +logger = logging.get_logger(__name__)  # pylint: disable=invalid-name + + +class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] + block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"] + + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ + "- for inpainting, you need to provide `mask_image` and `image`; optionally you can provide `padding_mask_crop`\n" + \ + "- to run the controlnet workflow, you need to provide `control_image`\n" + \ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ + "- for text-to-image generation, all you need to provide is `prompt`" + + + + diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 61ed023ce06b..011f23ed371c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -47,7 +47,6 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] - _import_structure["modular_pipeline"] = ["ModularPipeline"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
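# Illustrative sketch (not part of the diff): the runtime dispatch that
# StableDiffusionXLAutoPipeline's description above promises. get_execution_blocks()
# resolves which concrete sub-blocks would actually run for a given set of trigger
# inputs; this assumes `mask_image` and `control_image` are among the registered
# trigger inputs, as the description suggests.
auto = StableDiffusionXLAutoPipeline()

# Inpainting + controlnet path; with no trigger inputs, the default
# text-to-image sub-blocks are selected instead.
selected = auto.get_execution_blocks("mask_image", "control_image")
print(selected)  # a SequentialPipelineBlocks containing only the selected sub-blocks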
_import_structure["ddim"] = ["DDIMPipeline"] @@ -330,8 +329,6 @@ "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLModularPipeline", - "StableDiffusionXLAutoPipeline", ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] @@ -481,7 +478,6 @@ from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline - from .modular_pipeline import ModularPipeline from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, @@ -706,9 +702,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularPipeline, StableDiffusionXLPipeline, - StableDiffusionXLAutoPipeline, ) from .stable_video_diffusion import StableVideoDiffusionPipeline from .t2i_adapter import ( diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py deleted file mode 100644 index b50d00dbc219..000000000000 --- a/src/diffusers/pipelines/modular_pipeline.py +++ /dev/null @@ -1,1704 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import traceback -import warnings -from collections import OrderedDict -from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union - - -import torch -from tqdm.auto import tqdm -import re - -from ..configuration_utils import ConfigMixin -from ..utils import ( - is_accelerate_available, - is_accelerate_version, - logging, -) -from .pipeline_loading_utils import _get_pipeline_class - - -if is_accelerate_available(): - import accelerate - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -MODULAR_PIPELINE_MAPPING = OrderedDict( - [ - ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), - ] -) - - -@dataclass -class PipelineState: - """ - [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. 
- """ - - inputs: Dict[str, Any] = field(default_factory=dict) - intermediates: Dict[str, Any] = field(default_factory=dict) - - def add_input(self, key: str, value: Any): - self.inputs[key] = value - - def add_intermediate(self, key: str, value: Any): - self.intermediates[key] = value - - def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) - - def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - return {key: self.inputs.get(key, default) for key in keys} - - def get_intermediate(self, key: str, default: Any = None) -> Any: - return self.intermediates.get(key, default) - - def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - return {key: self.intermediates.get(key, default) for key in keys} - - def to_dict(self) -> Dict[str, Any]: - return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} - - def __repr__(self): - def format_value(v): - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]" - else: - return repr(v) - - inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) - intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) - - return ( - f"PipelineState(\n" - f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }}\n" - f")" - ) - - -@dataclass -class BlockState: - """ - Container for block state data with attribute access and formatted representation. - """ - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - def __repr__(self): - def format_value(v): - # Handle tensors directly - if hasattr(v, "shape") and hasattr(v, "dtype"): - return f"Tensor(dtype={v.dtype}, shape={v.shape})" - - # Handle lists of tensors - elif isinstance(v, list): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"List[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle tuples of tensors - elif isinstance(v, tuple): - if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): - shapes = [t.shape for t in v] - return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" - return repr(v) - - # Handle dicts with tensor values - elif isinstance(v, dict): - if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): - shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} - return f"Dict of Tensors with shapes {shapes}" - return repr(v) - - # Default case - return repr(v) - - attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items()) - return f"BlockState(\n{attributes}\n)" - - -@dataclass -class InputParam: - name: str - default: Any = None - required: bool = False - description: str = "" - type_hint: Any = Any - - def __repr__(self): - return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" - -@dataclass -class OutputParam: - name: str - description: str = "" - type_hint: Any = Any - - def __repr__(self): - return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" - -def format_inputs_short(inputs): - """ - Format input parameters into a string representation, with required params first followed by 
optional ones. - - Args: - inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params - - Returns: - str: Formatted string of input parameters - """ - required_inputs = [param for param in inputs if param.required] - optional_inputs = [param for param in inputs if not param.required] - - required_str = ", ".join(param.name for param in required_inputs) - optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - - inputs_str = required_str - if optional_str: - inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - - return inputs_str - - -def format_intermediates_short(intermediates_inputs: List[InputParam], required_intermediates_inputs: List[str], intermediates_outputs: List[OutputParam]) -> str: - """ - Formats intermediate inputs and outputs of a block into a string representation. - - Args: - intermediates_inputs: List of intermediate input parameters - required_intermediates_inputs: List of required intermediate input names - intermediates_outputs: List of intermediate output parameters - - Returns: - str: Formatted string like: - Intermediates: - - inputs: Required(latents), dtype - - modified: latents # variables that appear in both inputs and outputs - - outputs: images # new outputs only - """ - # Handle inputs - input_parts = [] - for inp in intermediates_inputs: - if inp.name in required_intermediates_inputs: - input_parts.append(f"Required({inp.name})") - else: - input_parts.append(inp.name) - - # Handle modified variables (appear in both inputs and outputs) - inputs_set = {inp.name for inp in intermediates_inputs} - modified_parts = [] - new_output_parts = [] - - for out in intermediates_outputs: - if out.name in inputs_set: - modified_parts.append(out.name) - else: - new_output_parts.append(out.name) - - result = [] - if input_parts: - result.append(f" - inputs: {', '.join(input_parts)}") - if modified_parts: - result.append(f" - modified: {', '.join(modified_parts)}") - if new_output_parts: - result.append(f" - outputs: {', '.join(new_output_parts)}") - - return "\n".join(result) if result else " (none)" - - -def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: int = 115) -> str: - """Format a list of InputParam or OutputParam objects into a readable string representation. - - Args: - params: List of InputParam or OutputParam objects to format - header: Header text to use (e.g. 
"Args" or "Returns") - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) - - Returns: - A formatted string representing all parameters - """ - if not params: - return "" - - base_indent = " " * indent_level - param_indent = " " * (indent_level + 4) - desc_indent = " " * (indent_level + 8) - formatted_params = [] - - def get_type_str(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: - types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] - return f"Union[{', '.join(types)}]" - return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - - def wrap_text(text: str, indent: str, max_length: int) -> str: - """Wrap text while preserving markdown links and maintaining indentation.""" - words = text.split() - lines = [] - current_line = [] - current_length = 0 - - for word in words: - word_length = len(word) + (1 if current_line else 0) - - if current_line and current_length + word_length > max_length: - lines.append(" ".join(current_line)) - current_line = [word] - current_length = len(word) - else: - current_line.append(word) - current_length += word_length - - if current_line: - lines.append(" ".join(current_line)) - - return f"\n{indent}".join(lines) - - # Add the header - formatted_params.append(f"{base_indent}{header}:") - - for param in params: - # Format parameter name and type - type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`" - - # Add optional tag and default value if parameter is an InputParam and optional - if isinstance(param, InputParam): - if not param.required: - param_str += ", *optional*" - if param.default is not None: - param_str += f", defaults to {param.default}" - param_str += "):" - - # Add description on a new line with additional indentation and wrapping - if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) - wrapped_desc = wrap_text(desc, desc_indent, max_line_length) - param_str += f"\n{desc_indent}{wrapped_desc}" - - formatted_params.append(param_str) - - return "\n\n".join(formatted_params) - -# Then update the original functions to use this combined version: -def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(input_params, "Args", indent_level, max_line_length) - -def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(output_params, "Returns", indent_level, max_line_length) - - - -def make_doc_string(inputs, intermediates_inputs, outputs, description=""): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Returns: - str: A formatted string containing information about call parameters, intermediate inputs/outputs, - and final intermediate outputs. - """ - output = "" - - if description: - desc_lines = description.strip().split('\n') - aligned_desc = '\n'.join(' ' + line for line in desc_lines) - output += aligned_desc + "\n\n" - - output += format_input_params(inputs + intermediates_inputs, indent_level=2) - - output += "\n\n" - output += format_output_params(outputs, indent_level=2) - - return output - - -class PipelineBlock: - # YiYi Notes: do we need this? 
- # pipeline block should set the default value for all expected config/components, so maybe we do not need to explicitly set the list - expected_components = [] - expected_configs = [] - model_name = None - - @property - def description(self) -> str: - """Description of the block. Must be implemented by subclasses.""" - raise NotImplementedError("description method must be implemented in subclasses") - - @property - def inputs(self) -> List[InputParam]: - """List of input parameters. Must be implemented by subclasses.""" - raise NotImplementedError("inputs method must be implemented in subclasses") - - @property - def intermediates_inputs(self) -> List[InputParam]: - """List of intermediate input parameters. Must be implemented by subclasses.""" - raise NotImplementedError("intermediates_inputs method must be implemented in subclasses") - - @property - def intermediates_outputs(self) -> List[OutputParam]: - """List of intermediate output parameters. Must be implemented by subclasses.""" - raise NotImplementedError("intermediates_outputs method must be implemented in subclasses") - - # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks - @property - def outputs(self) -> List[OutputParam]: - return self.intermediates_outputs - - @property - def required_inputs(self) -> List[str]: - input_names = [] - for input_param in self.inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - @property - def required_intermediates_inputs(self) -> List[str]: - input_names = [] - for input_param in self.intermediates_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - def __init__(self): - self.components: Dict[str, Any] = {} - self.auxiliaries: Dict[str, Any] = {} - self.configs: Dict[str, Any] = {} - - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise NotImplementedError("__call__ method must be implemented in subclasses") - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f"  Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f"      {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - expected_components = set(getattr(self, "expected_components", [])) - loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) - - main_components = [] - auxiliary_components = [] - for k in all_components: - component_str = f"      - {k}={type(self.components[k]).__name__}" if k in loaded_components else f"      - {k}" - if k in getattr(self, "auxiliary_components", []): - auxiliary_components.append(component_str) - else: - main_components.append(component_str) - - components = "Components:\n" + "\n".join(main_components) - if auxiliary_components: - components += "\n  Auxiliaries:\n" + "\n".join(auxiliary_components) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs = "Configs:\n" + "\n".join( - f"      - {k}={self.configs[k]}" if k in loaded_configs else f"      - {k}" - for k in all_configs - ) - - # Inputs section - inputs_str =
format_inputs_short(self.inputs) - inputs = "Inputs:\n " + inputs_str - - # Intermediates section - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates = f"Intermediates:\n{intermediates_str}" - - return ( - f"{class_name}(\n" - f" Class: {base_class}\n" - f"{desc}" - f" {components}\n" - f" {configs}\n" - f" {inputs}\n" - f" {intermediates}\n" - f")" - ) - - - @property - def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) - - - def get_block_state(self, state: PipelineState) -> dict: - """Get all inputs and intermediates in one dictionary""" - data = {} - - # Check inputs - for input_param in self.inputs: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - data[input_param.name] = value - - # Check intermediates - for input_param in self.intermediates_inputs: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - data[input_param.name] = value - - return BlockState(**data) - - def add_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediates_outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - state.add_intermediate(output_param.name, getattr(block_state, output_param.name)) - - -def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: - """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if - current default value is None and new default value is not None. Warns if multiple non-None default values - exist for the same input. - - Args: - named_input_lists: List of tuples containing (block_name, input_param_list) pairs - - Returns: - List[InputParam]: Combined list of unique InputParam objects - """ - combined_dict = {} # name -> InputParam - value_sources = {} # name -> block_name - - for block_name, inputs in named_input_lists: - for input_param in inputs: - if input_param.name in combined_dict: - current_param = combined_dict[input_param.name] - if (current_param.default is not None and - input_param.default is not None and - current_param.default != input_param.default): - warnings.warn( - f"Multiple different default values found for input '{input_param.name}': " - f"{current_param.default} (from block '{value_sources[input_param.name]}') and " - f"{input_param.default} (from block '{block_name}'). Using {current_param.default}." - ) - if current_param.default is None and input_param.default is not None: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - else: - combined_dict[input_param.name] = input_param - value_sources[input_param.name] = block_name - - return list(combined_dict.values()) - -def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: - """ - Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, - keeps the first occurrence of each output name. 
- - Args: - named_output_lists: List of tuples containing (block_name, output_param_list) pairs - - Returns: - List[OutputParam]: Combined list of unique OutputParam objects - """ - combined_dict = {}  # name -> OutputParam - - for block_name, outputs in named_output_lists: - for output_param in outputs: - if output_param.name not in combined_dict: - combined_dict[output_param.name] = output_param - - return list(combined_dict.values()) - - -class AutoPipelineBlocks: - """ - A class that automatically selects a block to run based on the inputs. - - Attributes: - block_classes: List of block classes to be used - block_names: List of prefixes for each block - block_trigger_inputs: List of input names that trigger specific blocks, with None for default - """ - - block_classes = [] - block_names = [] - block_trigger_inputs = [] - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): - raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") - default_blocks = [t for t in self.block_trigger_inputs if t is None] - # can only have 1 or 0 default blocks, and the default has to be put last - # the order of blocks matters here because the first block with a matching trigger will be dispatched - # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] - # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img - if len(default_blocks) > 1 or ( - len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None - ): - raise ValueError( - f"In {self.__class__.__name__}, at most one None can be specified, and it must be the last element " - "in block_trigger_inputs." - ) - - # Map trigger inputs to block objects - self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values())) - self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.blocks.keys())) - self.block_to_trigger_map = dict(zip(self.blocks.keys(), self.block_trigger_inputs)) - - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name - - @property - def description(self): - return "" - - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc - @property - def components(self): - # Combine components from all blocks - components = {} - for block_name, block in self.blocks.items(): - for key, value in block.components.items(): - # Only update if: - # 1. Key doesn't exist yet in components, OR - # 2.
New value is not None - if key not in components or value is not None: - components[key] = value - return components - - @property - def auxiliaries(self): - # Combine auxiliaries from all blocks - auxiliaries = {} - for block_name, block in self.blocks.items(): - auxiliaries.update(block.auxiliaries) - return auxiliaries - - @property - def configs(self): - # Combine configs from all blocks - configs = {} - for block_name, block in self.blocks.items(): - configs.update(block.configs) - return configs - - @property - def required_inputs(self) -> List[str]: - first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - - @property - def required_intermediates_inputs(self) -> List[str]: - first_block = next(iter(self.blocks.values())) - required_by_all = set(getattr(first_block, "required_intermediates_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_intermediates_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - - @property - def intermediates_inputs(self) -> List[str]: - named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_intermediates_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @property - def outputs(self) -> List[str]: - named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Find default block first (if any) - - block = self.trigger_to_block_map.get(None) - for input_name in self.block_trigger_inputs: - if input_name is not None and state.get_input(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - elif input_name is not None and state.get_intermediate(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - - if block is None: - logger.warning(f"skipping auto block: {self.__class__.__name__}") - return pipeline, state - - try: - logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") - return block(pipeline, state) - except 
Exception as e: - error_msg = ( - f"\nError in block: {block.__class__.__name__}\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - - def _get_trigger_inputs(self): - """ - Returns a set of all unique trigger input values found in the blocks. - Returns: Set[str] containing all unique block_trigger_inputs values - """ - def fn_recursive_get_trigger(blocks): - trigger_values = set() - - if blocks is not None: - for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list - trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - - # If block has blocks, recursively check them - if hasattr(block, 'blocks'): - nested_triggers = fn_recursive_get_trigger(block.blocks) - trigger_values.update(nested_triggers) - - return trigger_values - - trigger_inputs = set(self.block_trigger_inputs) - trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) - - return trigger_inputs - - @property - def trigger_inputs(self): - return self._get_trigger_inputs() - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" - ) - - - if self.trigger_inputs: - header += "\n" - header += " " + "=" * 100 + "\n" - header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" - header += " " + "=" * 100 + "\n\n" - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - expected_components = set(getattr(self, "expected_components", [])) - loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) - components_str = " Components:\n" + "\n".join( - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" - for k in all_components - ) - - # Auxiliaries section - auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() - ) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs_str = " Configs:\n" + "\n".join( - f" - {k}={v}" if k in loaded_configs else f" - {k}" for k, v in self.configs.items() - ) - - blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, 'block_to_trigger_map'): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - - # Add block description - desc_lines = block.description.split('\n') - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) - - return ( - f"{header}\n" - f"{desc}" - f"{components_str}\n" - f"{auxiliaries_str}\n" - f"{configs_str}\n" - f"{blocks_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" - f")" - ) - - @property - def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) - -class SequentialPipelineBlocks: - """ - A 
class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. - """ - block_classes = [] - block_names = [] - - @property - def model_name(self): - return next(iter(self.blocks.values())).model_name - - @property - def description(self): - return "" - - @property - def expected_components(self): - expected_components = [] - for block in self.blocks.values(): - for component in block.expected_components: - if component not in expected_components: - expected_components.append(component) - return expected_components - - @property - def expected_configs(self): - expected_configs = [] - for block in self.blocks.values(): - for config in block.expected_configs: - if config not in expected_configs: - expected_configs.append(config) - return expected_configs - - @classmethod - def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": - """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. - - Args: - blocks_dict: Dictionary mapping block names to block instances - - Returns: - A new SequentialPipelineBlocks instance - """ - instance = cls() - instance.block_classes = [block.__class__ for block in blocks_dict.values()] - instance.block_names = list(blocks_dict.keys()) - instance.blocks = blocks_dict - return instance - - def __init__(self): - blocks = OrderedDict() - for block_name, block_cls in zip(self.block_names, self.block_classes): - blocks[block_name] = block_cls() - self.blocks = blocks - - # YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc - @property - def components(self): - # Combine components from all blocks - components = {} - for block_name, block in self.blocks.items(): - for key, value in block.components.items(): - # Only update if: - # 1. Key doesn't exist yet in components, OR - # 2. 
New value is not None - if key not in components or value is not None: - components[key] = value - return components - - @property - def auxiliaries(self): - # Combine auxiliaries from all blocks - auxiliaries = {} - for block_name, block in self.blocks.items(): - auxiliaries.update(block.auxiliaries) - return auxiliaries - - @property - def configs(self): - # Combine configs from all blocks - configs = {} - for block_name, block in self.blocks.items(): - configs.update(block.configs) - return configs - - @property - def required_inputs(self) -> List[str]: - # Get the first block from the dictionary - first_block = next(iter(self.blocks.values())) - required_by_any = set(getattr(first_block, "required_inputs", set())) - - # Union with required inputs from all other blocks - for block in list(self.blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_any.update(block_required) - - return list(required_by_any) - - @property - def required_intermediates_inputs(self) -> List[str]: - required_intermediates_inputs = [] - for input_param in self.intermediates_inputs: - if input_param.required: - required_intermediates_inputs.append(input_param.name) - return required_intermediates_inputs - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - named_inputs = [(name, block.inputs) for name, block in self.blocks.items()] - combined_inputs = combine_inputs(*named_inputs) - # mark an input as required only if it is required by any of the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediates_inputs(self) -> List[str]: - inputs = [] - outputs = set() - - # Go through all blocks in order - for block in self.blocks.values(): - # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs) - - # Only add outputs if the block cannot be skipped - should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: - should_add_outputs = False - - if should_add_outputs: - # Add this block's outputs - block_intermediates_outputs = [out.name for out in block.intermediates_outputs] - outputs.update(block_intermediates_outputs) - return inputs - - @property - def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] - combined_outputs = combine_outputs(*named_outputs) - return combined_outputs - - @property - def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - for block_name, block in self.blocks.items(): - try: - pipeline, state = block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: ({block_name}, {block.__class__.__name__})\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - return pipeline, state - - def _get_trigger_inputs(self): - """ - Returns a set of all unique trigger input values found in the blocks.
- Returns: Set[str] containing all unique block_trigger_inputs values - """ - def fn_recursive_get_trigger(blocks): - trigger_values = set() - - if blocks is not None: - for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list - trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - - # If block has blocks, recursively check them - if hasattr(block, 'blocks'): - nested_triggers = fn_recursive_get_trigger(block.blocks) - trigger_values.update(nested_triggers) - - return trigger_values - - return fn_recursive_get_trigger(self.blocks) - - @property - def trigger_inputs(self): - return self._get_trigger_inputs() - - def _traverse_trigger_blocks(self, trigger_inputs): - # Convert trigger_inputs to a set for easier manipulation - active_triggers = set(trigger_inputs) - - def fn_recursive_traverse(block, block_name, active_triggers): - result_blocks = OrderedDict() - - # sequential or PipelineBlock - if not hasattr(block, 'block_trigger_inputs'): - if hasattr(block, 'blocks'): - # sequential - for block_name, block in block.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) - result_blocks.update(blocks_to_update) - else: - # PipelineBlock - result_blocks[block_name] = block - # Add this block's output names to active triggers if defined - if hasattr(block, 'outputs'): - active_triggers.update(out.name for out in block.outputs) - return result_blocks - - # auto - else: - # Find first block_trigger_input that matches any value in our active_triggers - this_block = None - matching_trigger = None - for trigger_input in block.block_trigger_inputs: - if trigger_input is not None and trigger_input in active_triggers: - this_block = block.trigger_to_block_map[trigger_input] - matching_trigger = trigger_input - break - - # If no matches found, try to get the default (None) block - if this_block is None and None in block.block_trigger_inputs: - this_block = block.trigger_to_block_map[None] - matching_trigger = None - - if this_block is not None: - # sequential/auto - if hasattr(this_block, 'blocks'): - result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) - else: - # PipelineBlock - result_blocks[block_name] = this_block - # Add this block's output names to active triggers if defined - if hasattr(this_block, 'outputs'): - active_triggers.update(out.name for out in this_block.outputs) - - return result_blocks - - all_blocks = OrderedDict() - for block_name, block in self.blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) - all_blocks.update(blocks_to_update) - return all_blocks - - def get_execution_blocks(self, *trigger_inputs): - trigger_inputs_all = self.trigger_inputs - - if trigger_inputs is not None: - - if not isinstance(trigger_inputs, (list, tuple, set)): - trigger_inputs = [trigger_inputs] - invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] - if invalid_inputs: - logger.warning( - f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" - ) - trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] - - if trigger_inputs is None: - if None in trigger_inputs_all: - trigger_inputs = [None] - else: - trigger_inputs = [trigger_inputs_all[0]] - blocks_triggered = 
self._traverse_trigger_blocks(trigger_inputs) - return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" - ) - - - if self.trigger_inputs: - header += "\n" - header += " " + "=" * 100 + "\n" - header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {self.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" - header += " " + "=" * 100 + "\n\n" - - # Format description with proper indentation - desc_lines = self.description.split('\n') - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' - - # Components section - expected_components = set(getattr(self, "expected_components", [])) - loaded_components = set(self.components.keys()) - all_components = sorted(expected_components | loaded_components) - components_str = " Components:\n" + "\n".join( - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" - for k in all_components - ) - - # Auxiliaries section - auxiliaries_str = " Auxiliaries:\n" + "\n".join( - f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items() - ) - - # Configs section - expected_configs = set(getattr(self, "expected_configs", [])) - loaded_configs = set(self.configs.keys()) - all_configs = sorted(expected_configs | loaded_configs) - configs_str = " Configs:\n" + "\n".join( - f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs - ) - - blocks_str = " Blocks:\n" - for i, (name, block) in enumerate(self.blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, 'block_to_trigger_map'): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - - # Add block description - desc_lines = block.description.split('\n') - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += 
f"{indented_intermediates}\n" - blocks_str += "\n" - - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) - - return ( - f"{header}\n" - f"{desc}" - f"{components_str}\n" - f"{auxiliaries_str}\n" - f"{configs_str}\n" - f"{blocks_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" - f")" - ) - - @property - def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) - -class ModularPipeline(ConfigMixin): - """ - Base class for all Modular pipelines. - - """ - - config_name = "model_index.json" - _exclude_from_cpu_offload = [] - - def __init__(self, block): - self.pipeline_block = block - - # add default components from pipeline_block (e.g. guider) - for key, value in block.components.items(): - setattr(self, key, value) - - # add default configs from pipeline_block (e.g. force_zeros_for_empty_prompt) - self.register_to_config(**block.configs) - - # add default auxiliaries from pipeline_block (e.g. image_processor) - for key, value in block.auxiliaries.items(): - setattr(self, key, value) - - @classmethod - def from_block(cls, block): - modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[block.model_name] - modular_pipeline_class = _get_pipeline_class(cls, class_name=modular_pipeline_class_name) - - return modular_pipeline_class(block) - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.device - - return torch.device("cpu") - - @property - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from - Accelerate's module hooks. - """ - for name, model in self.components.items(): - if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: - continue - - if not hasattr(model, "_hf_hook"): - return self.device - for module in model.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - - def get_execution_blocks(self, *trigger_inputs): - return self.pipeline_block.get_execution_blocks(*trigger_inputs) - - @property - def dtype(self) -> torch.dtype: - r""" - Returns: - `torch.dtype`: The torch dtype on which the pipeline is located. 
- """ - modules = self.components.values() - modules = [m for m in modules if isinstance(m, torch.nn.Module)] - - for module in modules: - return module.dtype - - return torch.float32 - - @property - def expected_components(self): - return self.pipeline_block.expected_components - - @property - def expected_configs(self): - return self.pipeline_block.expected_configs - - @property - def components(self): - components = {} - for name in self.expected_components: - if hasattr(self, name): - components[name] = getattr(self, name) - return components - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.set_progress_bar_config - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - # Make a copy of the input kwargs - input_params = kwargs.copy() - - default_params = self.default_call_parameters - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediates_inputs = [inp.name for inp in self.pipeline_block.intermediates_inputs] - for name, default in default_params.items(): - if name in input_params: - if name not in intermediates_inputs: - state.add_input(name, input_params.pop(name)) - else: - state.add_input(name, input_params[name]) - elif name not in state.inputs: - state.add_input(name, default) - - for name in intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) - - # Warn about unexpected inputs - if len(input_params) > 0: - logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self.pipeline_block(self, state) - except Exception: - error_msg = f"Error in block: ({self.pipeline_block.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - - def update_states(self, **kwargs): - """ - Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for - each pipeline block, does not need to be updated by users. Logs if existing non-None components are being - overwritten. - - Args: - kwargs (dict): Keyword arguments to update the states. 
- """ - - for component_name in self.expected_components: - if component_name in kwargs: - if hasattr(self, component_name) and getattr(self, component_name) is not None: - current_component = getattr(self, component_name) - new_component = kwargs[component_name] - - if not isinstance(new_component, current_component.__class__): - logger.info( - f"Overwriting existing component '{component_name}' " - f"(type: {current_component.__class__.__name__}) " - f"with type: {new_component.__class__.__name__})" - ) - elif isinstance(current_component, torch.nn.Module): - if id(current_component) != id(new_component): - logger.info( - f"Overwriting existing component '{component_name}' " - f"(type: {type(current_component).__name__}) " - f"with new value (type: {type(new_component).__name__})" - ) - - setattr(self, component_name, kwargs.pop(component_name)) - - configs_to_add = {} - for config_name in self.expected_configs: - if config_name in kwargs: - configs_to_add[config_name] = kwargs.pop(config_name) - self.register_to_config(**configs_to_add) - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.pipeline_block.inputs: - params[input_param.name] = input_param.default - return params - - def __repr__(self): - output = "ModularPipeline:\n" - output += "==============================\n\n" - - block = self.pipeline_block - - # List the pipeline block structure first - output += "Pipeline Block:\n" - output += "--------------\n" - if hasattr(block, "blocks"): - output += f"{block.__class__.__name__}\n" - base_class = block.__class__.__bases__[0].__name__ - output += f" (Class: {base_class})\n" if base_class != "object" else "\n" - for sub_block_name, sub_block in block.blocks.items(): - if hasattr(block, "block_trigger_inputs"): - trigger_input = block.block_to_trigger_map[sub_block_name] - trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" - output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" - else: - output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" - else: - output += f"{block.__class__.__name__}\n" - output += "\n" - - # List the components registered in the pipeline - output += "Registered Components:\n" - output += "----------------------\n" - for name, component in self.components.items(): - output += f"{name}: {type(component).__name__}" - if hasattr(component, "dtype") and hasattr(component, "device"): - output += f" (dtype={component.dtype}, device={component.device})" - output += "\n" - output += "\n" - - # List the configs registered in the pipeline - output += "Registered Configs:\n" - output += "------------------\n" - for name, config in self.config.items(): - output += f"{name}: {config!r}\n" - output += "\n" - - # Add auto blocks section - if hasattr(block, "trigger_inputs") and block.trigger_inputs: - output += "------------------\n" - output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" - output += f"Trigger Inputs: {block.trigger_inputs}\n" - # Get first trigger input as example - example_input = next(t for t in block.trigger_inputs if t is not None) - output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" - output += "Check `.doc` of returned object for more information.\n\n" - - # List the call parameters - full_doc = self.pipeline_block.doc - if "------------------------" in full_doc: - full_doc = full_doc.split("------------------------")[0].rstrip() - output += full_doc - - return output - - # YiYi TO-DO: try to unify the to method with the one in DiffusionPipeline - # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to - def to(self, *args, **kwargs): - r""" - Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the - arguments of `self.to(*args, **kwargs)`. - - - - If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise, - the returned pipeline is a copy of self with the desired torch.dtype and torch.device. - - - - - Here are the ways to call `to`: - - - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified - [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) - - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the - specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - - Arguments: - dtype (`torch.dtype`, *optional*): - Returns a pipeline with the specified - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - device (`torch.device`, *optional*): - Returns a pipeline with the specified - [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) - silence_dtype_warnings (`bool`, *optional*, defaults to `False`): - Whether to omit warnings if the target `dtype` is not compatible with the target `device`. - - Returns: - [`DiffusionPipeline`]: The pipeline converted to the specified `device` and/or `dtype`. - """ - dtype = kwargs.pop("dtype", None) - device = kwargs.pop("device", None) - silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) - - dtype_arg = None - device_arg = None - if len(args) == 1: - if isinstance(args[0], torch.dtype): - dtype_arg = args[0] - else: - device_arg = torch.device(args[0]) if args[0] is not None else None - elif len(args) == 2: - if isinstance(args[0], torch.dtype): - raise ValueError( - "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." - ) - device_arg = torch.device(args[0]) if args[0] is not None else None - dtype_arg = args[1] - elif len(args) > 2: - raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) to `.to(...)`") - - if dtype is not None and dtype_arg is not None: - raise ValueError( - "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." - ) - - dtype = dtype or dtype_arg - - if device is not None and device_arg is not None: - raise ValueError( - "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." - ) - - device = device or device_arg - - # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
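# Editorial note (not in the original source): the argument parsing above mirrors
# `torch.nn.Module.to`, so the following calls are intended to be equivalent:
#
#     pipe.to("cuda")                              # device only
#     pipe.to(torch.float16)                       # dtype only
#     pipe.to("cuda", torch.float16)               # device first, then dtype
#     pipe.to(device="cuda", dtype=torch.float16)  # keyword form
#
# Passing a dtype as the first of two positional arguments, or passing `device`/`dtype`
# both positionally and as keywords, raises a ValueError (see the checks above).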
- def module_is_sequentially_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): - return False - - return hasattr(module, "_hf_hook") and ( - isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) - or hasattr(module._hf_hook, "hooks") - and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook) - ) - - def module_is_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): - return False - - return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) - - # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer - pipeline_is_sequentially_offloaded = any( - module_is_sequentially_offloaded(module) for _, module in self.components.items() - ) - if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": - raise ValueError( - "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." - ) - - is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 - if is_pipeline_device_mapped: - raise ValueError( - "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`." - ) - - # Display a warning in this case (the operation succeeds but the benefits are lost) - pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) - if pipeline_is_offloaded and device and torch.device(device).type == "cuda": - logger.warning( - f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." - ) - - modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] - - is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded - for module in modules: - is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit - - if is_loaded_in_8bit and dtype is not None: - logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." - ) - - if is_loaded_in_8bit and device is not None: - logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {device} via `.to()` is not yet supported. Module is still on {module.device}." - ) - else: - module.to(device, dtype) - - if ( - module.dtype == torch.float16 - and str(device) in ["cpu"] - and not silence_dtype_warnings - and not is_offloaded - ): - logger.warning( - "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. 
It" - " is not recommended to move them to `cpu` as running them will fail. Please make" - " sure to use an accelerator to run the pipeline in inference, due to the lack of" - " support for`float16` operations on this device in PyTorch. Please, remove the" - " `torch_dtype=torch.float16` argument, or use another device for inference." - ) - return self diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 2b8afeffa00a..8b422798713f 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -331,6 +331,20 @@ def maybe_raise_or_warn( ) +# a simpler version of get_class_obj_and_candidates, it won't work with custom code +def simple_get_class_obj(library_name, class_name): + from diffusers import pipelines + is_pipeline_module = hasattr(pipelines, library_name) + + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + else: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + + return class_obj + def get_class_obj_and_candidates( library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None ): @@ -412,7 +426,7 @@ def _get_pipeline_class( revision=revision, ) - if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline": + if class_obj.__name__ != "DiffusionPipeline": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) @@ -839,7 +853,10 @@ def _fetch_class_library_tuple(module): library = not_compiled_module.__module__ # retrieve class_name - class_name = not_compiled_module.__class__.__name__ + if isinstance(not_compiled_module, type): + class_name = not_compiled_module.__name__ + else: + class_name = not_compiled_module.__class__.__name__ return (library, class_name) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8f9486aa6386..49575e99763a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1948,9 +1948,10 @@ def from_pipe(cls, pipeline, **kwargs): f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs } + optional_components = pipeline._optional_components if hasattr(pipeline, "_optional_components") and pipeline._optional_components else [] missing_modules = ( set(expected_modules) - - set(pipeline._optional_components) + - set(optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules) ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 584b260eaaa8..8088fbcfceba 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -29,18 +29,6 @@ _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"] _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] - _import_structure["pipeline_stable_diffusion_xl_modular"] = [ - "StableDiffusionXLControlNetDenoiseStep", - "StableDiffusionXLDecodeLatentsStep", - "StableDiffusionXLDenoiseStep", - "StableDiffusionXLInputStep", - "StableDiffusionXLModularPipeline", - 
"StableDiffusionXLPrepareAdditionalConditioningStep", - "StableDiffusionXLPrepareLatentsStep", - "StableDiffusionXLSetTimestepsStep", - "StableDiffusionXLTextEncoderStep", - "StableDiffusionXLAutoPipeline", - ] if is_transformers_available() and is_flax_available(): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState @@ -60,18 +48,6 @@ from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline - from .pipeline_stable_diffusion_xl_modular import ( - StableDiffusionXLControlNetDenoiseStep, - StableDiffusionXLDecodeLatentsStep, - StableDiffusionXLDenoiseStep, - StableDiffusionXLInputStep, - StableDiffusionXLModularPipeline, - StableDiffusionXLPrepareAdditionalConditioningStep, - StableDiffusionXLPrepareLatentsStep, - StableDiffusionXLSetTimestepsStep, - StableDiffusionXLTextEncoderStep, - StableDiffusionXLAutoPipeline, - ) try: if not (is_transformers_available() and is_flax_available()): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py deleted file mode 100644 index f743f442cc40..000000000000 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ /dev/null @@ -1,3909 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import inspect -from typing import Any, List, Optional, Tuple, Union, Dict - -import PIL -import torch -from collections import OrderedDict - -from ...guider import CFGGuider -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection -from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor -from ...models.lora import adjust_lora_scale_text_encoder -from ...utils import ( - USE_PEFT_BACKEND, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from ...utils.torch_utils import is_compiled_module, randn_tensor -from ..controlnet.multicontrolnet import MultiControlNetModel -from ..modular_pipeline import ( - AutoPipelineBlocks, - ModularPipeline, - PipelineBlock, - PipelineState, - InputParam, - OutputParam, - SequentialPipelineBlocks, -) -from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from .pipeline_output import ( - StableDiffusionXLPipelineOutput, -) - -import numpy as np - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." 
- ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - - -class StableDiffusionXLLoraStep(PipelineBlock): - expected_components = ["text_encoder", "text_encoder_2", "unet"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Lora step that handles all the lora related tasks: load/unload lora weights into unet and text encoders, manage lora adapters, etc." - " See [StableDiffusionXLLoraLoaderMixin](https://huggingface.co/docs/diffusers/api/loaders/lora#diffusers.loaders.StableDiffusionXLLoraLoaderMixin)" - " for more details" - ) - - - @property - def inputs(self) -> List[InputParam]: - return [] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [] - - def __init__(self): - super().__init__() - self.components["text_encoder"] = None - self.components["text_encoder_2"] = None - self.components["unet"] = None - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise EnvironmentError("StableDiffusionXLLoraStep is designed to be used to load lora weights; __call__ is not implemented") - - -class StableDiffusionXLIPAdapterStep(PipelineBlock): - expected_components = ["image_encoder", "feature_extractor", "unet"] - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "IP Adapter step that handles all the ip adapter related tasks: load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc" - " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" - " for more details" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "ip_adapter_image", - required=True, - type_hint=PipelineImageInput, - description="The image(s) to be used as ip adapter" - ), - InputParam( - "guidance_scale", - default=5.0, - description="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale is enabled by setting `guidance_scale > 1`." 
- ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), - OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") - ] - - def __init__(self): - super().__init__() - self.components["image_encoder"] = None - self.components["feature_extractor"] = None - self.components["unet"] = None - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.do_classifier_free_guidance = data.guidance_scale > 1.0 - data.device = pipeline._execution_device - - data.ip_adapter_embeds = pipeline.prepare_ip_adapter_image_embeds( - ip_adapter_image=data.ip_adapter_image, - ip_adapter_image_embeds=None, - device=data.device, - num_images_per_prompt=1, - do_classifier_free_guidance=data.do_classifier_free_guidance, - ) - if data.do_classifier_free_guidance: - data.negative_ip_adapter_embeds = [] - for i, image_embeds in enumerate(data.ip_adapter_embeds): - negative_image_embeds, image_embeds = image_embeds.chunk(2) - data.negative_ip_adapter_embeds.append(negative_image_embeds) - data.ip_adapter_embeds[i] = image_embeds - - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLTextEncoderStep(PipelineBlock): - expected_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"] - expected_configs = ["force_zeros_for_empty_prompt"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Text Encoder step that generates text_embeddings to guide the image generation" - ) - - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - name="prompt", - type_hint=Union[str, List[str]], - description="The prompt or prompts to guide the image generation.", - ), - InputParam( - name="prompt_2", - type_hint=Union[str, List[str]], - description="The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders", - ), - InputParam( - name="negative_prompt", - type_hint=Union[str, List[str]], - description="The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).", - ), - InputParam( - name="negative_prompt_2", - type_hint=Union[str, List[str]], - description="The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders", - ), - InputParam( - name="cross_attention_kwargs", - type_hint=Optional[dict], - description="A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor]", - ), - InputParam( - name="guidance_scale", - type_hint=float, - default=5.0, - description="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.", - ), - InputParam( - name="clip_skip", - type_hint=Optional[int], - ), - ] - - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), - ] - - def __init__(self): - super().__init__() - self.configs["force_zeros_for_empty_prompt"] = True - self.components["text_encoder"] = None - self.components["text_encoder_2"] = None - self.components["tokenizer"] = None - self.components["tokenizer_2"] = None - - def check_inputs(self, pipeline, data): - - if data.prompt is not None and (not isinstance(data.prompt, str) and not isinstance(data.prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(data.prompt)}") - elif data.prompt_2 is not None and (not isinstance(data.prompt_2, str) and not isinstance(data.prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(data.prompt_2)}") - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Get inputs and intermediates - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.do_classifier_free_guidance = data.guidance_scale > 1.0 - data.device = pipeline._execution_device - - - # Encode input prompt - data.text_encoder_lora_scale = ( - data.cross_attention_kwargs.get("scale", None) if data.cross_attention_kwargs is not None else None - ) - ( - data.prompt_embeds, - data.negative_prompt_embeds, - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) = pipeline.encode_prompt( - data.prompt, - data.prompt_2, - data.device, - 1, - data.do_classifier_free_guidance, - data.negative_prompt, - data.negative_prompt_2, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - lora_scale=data.text_encoder_lora_scale, - clip_skip=data.clip_skip, - ) - # Add outputs - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLVaeEncoderStep(PipelineBlock): - expected_components = ["vae"] - model_name = "stable-diffusion-xl" - - - @property - def description(self) -> str: - return ( - "Vae Encoder step that encodes the input image into a latent representation" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - name="image", - type_hint=PipelineImageInput, - required=True, - description="The image(s) to modify with the pipeline, for the img2img or inpainting task. When used for inpainting, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." 
- ), - InputParam( - name="generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic." - ), - InputParam( - name="height", - type_hint=Optional[int], - description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - InputParam( - name="width", - type_hint=Optional[int], - description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")] - - def __init__(self): - super().__init__() - self.components["vae"] = None - self.auxiliaries["image_processor"] = VaeImageProcessor() - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.preprocess_kwargs = data.preprocess_kwargs or {} - data.device = pipeline._execution_device - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, **data.preprocess_kwargs) - data.image = data.image.to(device=data.device, dtype=data.dtype) - - data.batch_size = data.image.shape[0] - - # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(data.generator, list) and len(data.generator) != data.batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(data.generator)}, but requested an effective batch" - f" size of {data.batch_size}. Make sure the batch size matches the length of the generators." - ) - - - data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): - expected_components = ["vae"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Vae encoder step that prepares the image and mask for the inpainting process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "height", - type_hint=Optional[int], - description="The height in pixels of the generated image. This is set to 1024 by default for the best results. 
" - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - InputParam( - "width", - type_hint=Optional[int], - description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions.", - ), - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic." - ), - InputParam( - "image", - required=True, - type_hint=PipelineImageInput, - description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." - ), - InputParam( - "mask_image", - required=True, - type_hint=PipelineImageInput, - description="`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be " - "repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted " - "to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) " - "instead of 3, so the expected shape would be `(B, H, W, 1)`." - ), - InputParam( - "padding_mask_crop", - type_hint=Optional[Tuple[int, int]], - description="The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to " - "image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region " - "with the same aspect ratio of the image and contains all masked area, and then expand that area based " - "on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before " - "resizing to the original image size for inpainting. This is useful when the masked area is small while " - "the image is large and contain information irrelevant for inpainting, such as background." 
- ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specific unet)"), - OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] - - def __init__(self): - super().__init__() - self.auxiliaries["image_processor"] = VaeImageProcessor() - self.auxiliaries["mask_processor"] = VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True) - self.components["vae"] = None - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - - if data.padding_mask_crop is not None: - data.crops_coords = pipeline.mask_processor.get_crop_region(data.mask_image, data.width, data.height, pad=data.padding_mask_crop) - data.resize_mode = "fill" - else: - data.crops_coords = None - data.resize_mode = "default" - - data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) - data.image = data.image.to(dtype=torch.float32) - - data.mask = pipeline.mask_processor.preprocess(data.mask_image, height=data.height, width=data.width, resize_mode=data.resize_mode, crops_coords=data.crops_coords) - data.masked_image = data.image * (data.mask < 0.5) - - data.batch_size = data.image.shape[0] - data.image = data.image.to(device=data.device, dtype=data.dtype) - data.image_latents = pipeline._encode_vae_image(image=data.image, generator=data.generator) - - # 7. Prepare mask latent variables - data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( - data.mask, - data.masked_image, - data.batch_size, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - ) - - self.add_block_state(state, data) - - - return pipeline, state - - -class StableDiffusionXLInputStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Input processing step that:\n" - " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" - " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" - "All input tensors are expected to have either batch_size=1 or match the batch_size\n" - "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" - "have a final batch_size of batch_size * num_images_per_prompt." 
- ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - name="num_images_per_prompt", - type_hint=int, - default=1, - description="The number of images to generate per prompt.", - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings. Can be generated from text_encoder step."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings. Can be generated from text_encoder step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings. Can be generated from text_encoder step."), - InputParam("negative_pooled_prompt_embeds", description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step."), - InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)"), - OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation"), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="negative pooled text embeddings used to guide the image generation"), - OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="image embeddings for IP-Adapter"), - OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="negative image embeddings for IP-Adapter"), - ] - - def check_inputs(self, pipeline, data): - - if data.prompt_embeds is not None and data.negative_prompt_embeds is not None: - if data.prompt_embeds.shape != data.negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {data.prompt_embeds.shape} != `negative_prompt_embeds`" - f" {data.negative_prompt_embeds.shape}." - ) - - if data.prompt_embeds is not None and data.pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if data.negative_prompt_embeds is not None and data.negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
- ) - - if data.ip_adapter_embeds is not None and not isinstance(data.ip_adapter_embeds, list): - raise ValueError("`ip_adapter_embeds` must be a list") - - if data.negative_ip_adapter_embeds is not None and not isinstance(data.negative_ip_adapter_embeds, list): - raise ValueError("`negative_ip_adapter_embeds` must be a list") - - if data.ip_adapter_embeds is not None and data.negative_ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - if ip_adapter_embed.shape != data.negative_ip_adapter_embeds[i].shape: - raise ValueError( - "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" - f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" - f" {data.negative_ip_adapter_embeds[i].shape}." - ) - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.batch_size = data.prompt_embeds.shape[0] - data.dtype = data.prompt_embeds.dtype - - _, seq_len, _ = data.prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - data.prompt_embeds = data.prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.prompt_embeds = data.prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) - - if data.negative_prompt_embeds is not None: - _, seq_len, _ = data.negative_prompt_embeds.shape - data.negative_prompt_embeds = data.negative_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_prompt_embeds = data.negative_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) - - data.pooled_prompt_embeds = data.pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.pooled_prompt_embeds = data.pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) - - if data.negative_pooled_prompt_embeds is not None: - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) - data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) - - if data.ip_adapter_embeds is not None: - for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): - data.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * data.num_images_per_prompt, dim=0) - - if data.negative_ip_adapter_embeds is not None: - for i, negative_ip_adapter_embed in enumerate(data.negative_ip_adapter_embeds): - data.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * data.num_images_per_prompt, dim=0) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): - expected_components = ["scheduler"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + \ - "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "num_inference_steps", - default=50, - type_hint=int, - description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the" - " expense of slower inference." 
- ), - InputParam( - "timesteps", - type_hint=Optional[torch.Tensor], - description="Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order." - ), - InputParam( - "sigmas", - type_hint=Optional[torch.Tensor], - description="Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used." - ), - InputParam( - "denoising_end", - type_hint=Optional[float], - description="When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup." - ), - InputParam( - "strength", - default=0.3, - type_hint=float, - description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " - "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " - "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " - "be maximum and the denoising process will run for the full number of iterations specified in " - "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " - "`denoising_start` being declared as an integer, the value of `strength` will be ignored." - ), - InputParam( - "denoising_start", - type_hint=Optional[float], - description="The denoising start value to use for the scheduler. Determines the starting point of the denoising process." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt. Defaults to 1." 
- ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), - OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") - ] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.device = pipeline._execution_device - - data.timesteps, data.num_inference_steps = retrieve_timesteps( - pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas - ) - - def denoising_value_valid(dnv): - return isinstance(dnv, float) and 0 < dnv < 1 - - data.timesteps, data.num_inference_steps = pipeline.get_timesteps( - data.num_inference_steps, - data.strength, - data.device, - denoising_start=data.denoising_start if denoising_value_valid(data.denoising_start) else None, - ) - data.latent_timestep = data.timesteps[:1].repeat(data.batch_size * data.num_images_per_prompt) - - if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: - data.discrete_timestep_cutoff = int( - round( - pipeline.scheduler.config.num_train_timesteps - - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) - ) - ) - data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) - data.timesteps = data.timesteps[:data.num_inference_steps] - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLSetTimestepsStep(PipelineBlock): - expected_components = ["scheduler"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that sets the scheduler's timesteps for inference" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "num_inference_steps", - default=50, - type_hint=int, - description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference." - ), - InputParam( - "timesteps", - type_hint=Optional[torch.Tensor], - description="Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order." - ), - InputParam( - "sigmas", - type_hint=Optional[torch.Tensor], - description="Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used." - ), - InputParam( - "denoising_end", - type_hint=Optional[float], - description="When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. 
As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.device = pipeline._execution_device - - data.timesteps, data.num_inference_steps = retrieve_timesteps( - pipeline.scheduler, data.num_inference_steps, data.device, data.timesteps, data.sigmas - ) - - if data.denoising_end is not None and isinstance(data.denoising_end, float) and data.denoising_end > 0 and data.denoising_end < 1: - data.discrete_timestep_cutoff = int( - round( - pipeline.scheduler.config.num_train_timesteps - - (data.denoising_end * pipeline.scheduler.config.num_train_timesteps) - ) - ) - data.num_inference_steps = len(list(filter(lambda ts: ts >= data.discrete_timestep_cutoff, data.timesteps))) - data.timesteps = data.timesteps[:data.num_inference_steps] - - self.add_block_state(state, data) - return pipeline, state - - class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): - expected_components = ["scheduler"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the inpainting process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic."), - InputParam( - "latents", - type_hint=Optional[torch.Tensor], - description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - InputParam( - "denoising_start", - type_hint=Optional[float], - description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups." - ), - InputParam( - "strength", - default=0.9999, - type_hint=float, - description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " - "will be used as a starting point, adding more noise to it the larger the `strength`.
The number of " - "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " - "be maximum and the denoising process will run for the full number of iterations specified in " - "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " - "`denoising_start` being declared as an integer, the value of `strength` will be ignored." - ), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "latent_timestep", - required=True, - type_hint=torch.Tensor, - description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." - ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "mask", - required=True, - type_hint=torch.Tensor, - description="The mask for the inpainting generation. Can be generated in vae_encode step." - ), - InputParam( - "masked_image_latents", - type_hint=torch.Tensor, - description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." - ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), - OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - - data.is_strength_max = data.strength == 1.0 - - # for non-inpainting specific unet, we do not need masked_image_latents - if hasattr(pipeline,"unet") and pipeline.unet is not None: - if pipeline.unet.config.in_channels == 4: - data.masked_image_latents = None - - data.add_noise = True if data.denoising_start is None else False - - data.height = data.image_latents.shape[-2] * pipeline.vae_scale_factor - data.width = data.image_latents.shape[-1] * pipeline.vae_scale_factor - - data.latents, data.noise = pipeline.prepare_latents_inpaint( - data.batch_size * data.num_images_per_prompt, - pipeline.num_channels_latents, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - data.latents, - image=data.image_latents, - timestep=data.latent_timestep, - is_strength_max=data.is_strength_max, - add_noise=data.add_noise, - return_noise=True, - return_image_latents=False, - ) - - # 7. 
Prepare mask latent variables - data.mask, data.masked_image_latents = pipeline.prepare_mask_latents( - data.mask, - data.masked_image_latents, - data.batch_size * data.num_images_per_prompt, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - ) - - self.add_block_state(state, data) - - return pipeline, state - - class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): - expected_components = ["vae", "scheduler"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the latents for the image-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic." - ), - InputParam( - "latents", - type_hint=Optional[torch.Tensor], - description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - InputParam( - "denoising_start", - type_hint=Optional[float], - description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups." - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), - InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt.
Can be generated in input step."), - InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - data.device = pipeline._execution_device - data.add_noise = True if data.denoising_start is None else False - if data.latents is None: - data.latents = pipeline.prepare_latents_img2img( - data.image_latents, - data.latent_timestep, - data.batch_size, - data.num_images_per_prompt, - data.dtype, - data.device, - data.generator, - data.add_noise, - ) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLPrepareLatentsStep(PipelineBlock): - expected_components = ["scheduler"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Prepare latents step that prepares the latents for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "height", - type_hint=Optional[int], - description="The height in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions."), - InputParam( - "width", - type_hint=Optional[int], - description="The width in pixels of the generated image. This is set to 1024 by default for the best results. " - "Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0]" - "(https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not " - "specifically fine-tuned on low resolutions."), - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) " - "to make generation deterministic." - ), - InputParam( - "latents", - type_hint=Optional[torch.Tensor], - description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." 
- ), - InputParam( - "dtype", - type_hint=torch.dtype, - description="The dtype of the model inputs" - ) - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "latents", - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process" - ) - ] - - def __init__(self): - super().__init__() - self.components["scheduler"] = None - - @staticmethod - def check_inputs(pipeline, data): - if ( - data.height is not None - and data.height % pipeline.vae_scale_factor != 0 - or data.width is not None - and data.width % pipeline.vae_scale_factor != 0 - ): - raise ValueError( - f"`height` and `width` have to be divisible by {pipeline.vae_scale_factor} but are {data.height} and {data.width}." - ) - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - - if data.dtype is None: - data.dtype = pipeline.vae.dtype - - data.device = pipeline._execution_device - - self.check_inputs(pipeline, data) - - data.height = data.height or pipeline.default_sample_size * pipeline.vae_scale_factor - data.width = data.width or pipeline.default_sample_size * pipeline.vae_scale_factor - data.num_channels_latents = pipeline.num_channels_latents - data.latents = pipeline.prepare_latents( - data.batch_size * data.num_images_per_prompt, - data.num_channels_latents, - data.height, - data.width, - data.dtype, - data.device, - data.generator, - data.latents, - ) - - self.add_block_state(state, data) - - return pipeline, state - - class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): - expected_configs = ["requires_aesthetics_score"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "original_size", - type_hint=Optional[Tuple[int]], - description="If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. " - "`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as " - "explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "target_size", - type_hint=Optional[Tuple[int]], - description="For most cases, `target_size` should be set to the desired height and width of the generated image. If " - "not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in " - "section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_original_size", - type_hint=Optional[Tuple[int]], - description="To negatively condition the generation process based on a specific image resolution. Part of SDXL's " - "micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_target_size", - type_hint=Optional[Tuple[int]], - description="To negatively condition the generation process based on a target image resolution. It should be the same " - "as the `target_size` for most cases.
Part of SDXL's micro-conditioning as explained in section 2.2 of " - "https://huggingface.co/papers/2307.01952" - ), - InputParam( - "crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int], - description="`crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position " - "`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning" - ), - InputParam( - "negative_crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int], - description="To negatively condition the generation process based on specific crop coordinates. Part of SDXL's " - "micro-conditioning" - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt." - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. `guidance_scale` is defined as `w` of equation 2. " - "Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, " - "usually at the expense of lower image quality." - ), - InputParam( - "aesthetic_score", - default=6.0, - type_hint=float, - description="Used to simulate an aesthetic score of the generated image by influencing the positive text condition. " - "Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_aesthetic_score", - default=2.0, - type_hint=float, - description="Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. " - "Can be used to simulate an aesthetic score of the generated image by influencing the negative text condition." - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt.
Can be generated in input step."), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - def __init__(self): - super().__init__() - self.configs["requires_aesthetics_score"] = False - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.vae_scale_factor = pipeline.vae_scale_factor - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * data.vae_scale_factor - data.width = data.width * data.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) - - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - - if data.negative_original_size is None: - data.negative_original_size = data.original_size - if data.negative_target_size is None: - data.negative_target_size = data.target_size - - data.add_time_ids, data.negative_add_time_ids = pipeline._get_add_time_ids_img2img( - data.original_size, - data.crops_coords_top_left, - data.target_size, - data.aesthetic_score, - data.negative_aesthetic_score, - data.negative_original_size, - data.negative_crops_coords_top_left, - data.negative_target_size, - dtype=data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, - ) - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - - # Optionally get Guidance Scale Embedding for LCM - data.timestep_cond = None - if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None - ): - data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = pipeline.get_guidance_scale_embedding( - data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=data.device, dtype=data.latents.dtype) - - self.add_block_state(state, data) - return pipeline, state - - -class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that prepares the additional conditioning for the text-to-image generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "original_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The original size (height, width) of the image that conditions the generation process. If different from target_size, the image will appear to be down- or upsampled. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "target_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The target size (height, width) of the generated image. 
For most cases, this should be set to the desired output dimensions. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_original_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The negative original size to condition against during generation. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "negative_target_size", - type_hint=Tuple[int, int], - default=(1024, 1024), - description="The negative target size to condition against during generation. Should typically match target_size. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int, int], - description="The top-left coordinates (x, y) used to condition the generation process. Setting this to (0, 0) typically produces well-centered images. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952" - ), - InputParam( - "negative_crops_coords_top_left", - default=(0, 0), - type_hint=Tuple[int, int], - description="The top-left coordinates (x, y) used to negatively condition the generation process. Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. For more information, see: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt" - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. `guidance_scale` is defined as `w` of equation 2. " - "Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, " - "usually at the expense of lower image quality."), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
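
The `add_time_ids` assembled from the size and crop inputs listed above are SDXL's micro-conditioning vector: original size, crop top-left, and target size flattened into one tensor (the img2img variant additionally appends aesthetic scores). A rough sketch of what `_get_add_time_ids` presumably computes for the text-to-image case:

import torch

def make_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=torch.float32):
    # SDXL conditions on (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
    ids = list(original_size) + list(crops_coords_top_left) + list(target_size)
    return torch.tensor([ids], dtype=dtype)

# e.g. a well-centered 1024x1024 generation:
add_time_ids = make_add_time_ids((1024, 1024), (0, 0), (1024, 1024))
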
- ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), - OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] - - @torch.no_grad() - def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - data.device = pipeline._execution_device - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - data.original_size = data.original_size or (data.height, data.width) - data.target_size = data.target_size or (data.height, data.width) - - data.text_encoder_projection_dim = int(data.pooled_prompt_embeds.shape[-1]) - - data.add_time_ids = pipeline._get_add_time_ids( - data.original_size, - data.crops_coords_top_left, - data.target_size, - data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, - ) - if data.negative_original_size is not None and data.negative_target_size is not None: - data.negative_add_time_ids = pipeline._get_add_time_ids( - data.negative_original_size, - data.negative_crops_coords_top_left, - data.negative_target_size, - data.pooled_prompt_embeds.dtype, - text_encoder_projection_dim=data.text_encoder_projection_dim, - ) - else: - data.negative_add_time_ids = data.add_time_ids - - data.add_time_ids = data.add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - data.negative_add_time_ids = data.negative_add_time_ids.repeat(data.batch_size * data.num_images_per_prompt, 1).to(device=data.device) - - # Optionally get Guidance Scale Embedding for LCM - data.timestep_cond = None - if ( - hasattr(pipeline, "unet") - and pipeline.unet is not None - and pipeline.unet.config.time_cond_proj_dim is not None - ): - data.guidance_scale_tensor = torch.tensor(data.guidance_scale - 1).repeat(data.batch_size * data.num_images_per_prompt) - data.timestep_cond = pipeline.get_guidance_scale_embedding( - data.guidance_scale_tensor, embedding_dim=pipeline.unet.config.time_cond_proj_dim - ).to(device=data.device, dtype=data.latents.dtype) - - self.add_block_state(state, data) - return pipeline, state - - class StableDiffusionXLDenoiseStep(PipelineBlock): - expected_components = ["unet", "scheduler", "guider"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return ( - "Step that iteratively denoises the latents for the text-to-image/image-to-image/inpainting generation process" - ) - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "guidance_scale", - type_hint=float, - default=5.0, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1." - ), - InputParam( - "guidance_rescale", - type_hint=float, - default=0.0, - description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'."
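
`guidance_rescale` refers to the variance-rescaling trick from 'Common Diffusion Noise Schedules and Sample Steps are Flawed': the guided prediction is rescaled toward the standard deviation of the text-conditioned branch, then blended back. A sketch of the usual formulation (not necessarily the guider's exact code):

import torch

def rescale_guided_noise(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Match the guided prediction's per-sample std to the text branch's std.
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    # Blend to avoid over-correcting; guidance_rescale=0.0 is a no-op.
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
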
- ), - InputParam( - "cross_attention_kwargs", - type_hint=Optional[Dict[str, Any]], - default=None, - description="Optional kwargs dictionary passed to the AttentionProcessor." - ), - InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of torch generator(s) to make generation deterministic." - ), - InputParam( - "eta", - type_hint=float, - default=0.0, - description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others." - ), - InputParam( - "guider_kwargs", - type_hint=Optional[Dict[str, Any]], - default=None, - description="Optional kwargs dictionary passed to the Guider." - ), - InputParam( - "num_images_per_prompt", - type_hint=int, - default=1, - description="The number of images to generate per prompt." - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models (LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step."
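
The guider's `prepare_input` calls used throughout these denoise blocks pair each conditional tensor with its unconditional counterpart so that one UNet forward pass covers both classifier-free guidance branches. Conceptually (a simplification; CFGGuider's real logic also handles disabled guidance and batch bookkeeping):

import torch

def cfg_prepare_input(cond, uncond=None, guidance_enabled=True):
    # Batch [uncond; cond] along dim 0; the two halves are split again
    # when the guidance rule is applied to the noise prediction.
    if not guidance_enabled or uncond is None:
        return cond
    return torch.cat([uncond, cond], dim=0)
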
- ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The IP-Adapter embeddings to use to condition the denoising process; requires an IP-Adapter model to be loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative IP-Adapter embeddings to use to condition the denoising process; requires an IP-Adapter model to be loaded. Can be generated in ip_adapter step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - def __init__(self): - super().__init__() - self.components["guider"] = CFGGuider() - self.components["scheduler"] = None - self.components["unet"] = None - - def check_inputs(self, pipeline, data): - - num_channels_unet = pipeline.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input.
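
The channel check here encodes the inpainting-specific UNet contract: 4 noisy-latent channels + 1 mask channel + 4 masked-image-latent channels = 9 input channels, concatenated along dim=1 at every step. A toy illustration of the arithmetic:

import torch

latents = torch.randn(1, 4, 128, 128)               # noisy latents
mask = torch.ones(1, 1, 128, 128)                   # binary inpainting mask
masked_image_latents = torch.randn(1, 4, 128, 128)  # VAE-encoded masked image

unet_input = torch.cat([latents, mask, masked_image_latents], dim=1)
assert unet_input.shape[1] == 9  # matches the inpainting UNet's in_channels
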
- ) - - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.num_channels_unet = pipeline.unet.config.in_channels - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - # Prepare conditional inputs using the guider - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds - - # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - # expand the latents if we are doing classifier free guidance - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - - # inpainting - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - # predict the noise residual - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance( - data.noise_pred, - timestep=t, - latents=data.latents, - ) - # compute the previous noisy sample x_t -> x_t-1 - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (e.g.
Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) - - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) - ) - - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): - progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - self.add_block_state(state, data) - - return pipeline, state - - class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): - expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "Step that iteratively denoises the latents for the text-to-image/image-to-image/inpainting generation process, using ControlNet to condition the denoising process" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "control_image", - required=True, - type_hint=PipelineImageInput, - description="The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass images as a list for proper batching." - ), - InputParam( - "control_guidance_start", - default=0.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet starts applying." - ), - InputParam( - "control_guidance_end", - default=1.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet stops applying." - ), - InputParam( - "controlnet_conditioning_scale", - default=1.0, - type_hint=Union[float, List[float]], - description="Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list of scales." - ), - InputParam( - "guess_mode", - default=False, - type_hint=bool, - description="Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt." - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality.
- ), - InputParam( - "generator", - default=None, - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of torch generator(s) to make generation deterministic." - ), - InputParam( - "eta", - default=0.0, - type_hint=float, - description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others." - ), - InputParam( - "guider_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the Guider." - ), - ] - - @property - def intermediates_inputs(self) -> List[InputParam]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models (LCMs). Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step."
- ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates used to preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The IP-Adapter embeddings to use to condition the denoising process; requires an IP-Adapter model to be loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative IP-Adapter embeddings to use to condition the denoising process; requires an IP-Adapter model to be loaded. Can be generated in ip_adapter step." - ), - ] - - @property - def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - def __init__(self): - super().__init__() - self.components["guider"] = CFGGuider() - self.components["controlnet_guider"] = CFGGuider() - self.components["scheduler"] = None - self.components["unet"] = None - self.components["controlnet"] = None - control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False) - self.auxiliaries["control_image_processor"] = control_image_processor - - def check_inputs(self, pipeline, data): - - num_channels_unet = pipeline.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input.
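
In the `__call__` that follows, `control_guidance_start`/`control_guidance_end` are turned into a per-step keep mask: the ControlNet residuals are applied only while the current step's fraction of the schedule lies inside the [start, end] window. A small standalone sketch of that schedule:

def controlnet_keep_schedule(num_steps, start=0.0, end=1.0):
    # 1.0 while step i falls inside the [start, end] fraction of the run, else 0.0.
    return [
        1.0 - float(i / num_steps < start or (i + 1) / num_steps > end)
        for i in range(num_steps)
    ]

# e.g. apply ControlNet only during the first half of 10 steps:
print(controlnet_keep_schedule(10, end=0.5))
# [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
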
- ) - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.num_channels_unet = pipeline.unet.config.in_channels - - # (1) prepare controlnet inputs - - data.device = pipeline._execution_device - - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet - - # (1.1) - # control_guidance_start/control_guidance_end (align format) - if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): - data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] - elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): - data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] - elif not isinstance(data.control_guidance_start, list) and not isinstance(data.control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - data.control_guidance_start, data.control_guidance_end = ( - mult * [data.control_guidance_start], - mult * [data.control_guidance_end], - ) - - # (1.2) - # controlnet_conditioning_scale (align format) - if isinstance(controlnet, MultiControlNetModel) and isinstance(data.controlnet_conditioning_scale, float): - data.controlnet_conditioning_scale = [data.controlnet_conditioning_scale] * len(controlnet.nets) - - # (1.3) - # global_pool_conditions - data.global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - # (1.4) - # guess_mode - data.guess_mode = data.guess_mode or data.global_pool_conditions - - # (1.5) - # control_image - if isinstance(controlnet, ControlNetModel): - data.control_image = pipeline.prepare_control_image( - image=data.control_image, - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, - dtype=controlnet.dtype, - crops_coords=data.crops_coords, - ) - elif isinstance(controlnet, MultiControlNetModel): - control_images = [] - - for control_image_ in data.control_image: - control_image = pipeline.prepare_control_image( - image=control_image_, - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, - dtype=controlnet.dtype, - crops_coords=data.crops_coords, - ) - - control_images.append(control_image) - - data.control_image = control_images - else: - assert False - - # (1.6) - # controlnet_keep - data.controlnet_keep = [] - for i in range(len(data.timesteps)): - keeps = [ - 1.0 - float(i / len(data.timesteps) < s or (i + 1) / len(data.timesteps) > e) - for s, e in zip(data.control_guidance_start, data.control_guidance_end) - ] - data.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - # (2) Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - data.guider_kwargs 
= data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - if data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds - - # (3) Prepare conditional inputs for controlnet using the guider - data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False - data.controlnet_guider_kwargs = data.guider_kwargs or {} - data.controlnet_guider_kwargs = { - **data.controlnet_guider_kwargs, - "disable_guidance": data.controlnet_disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) - data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - data.controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - } - data.control_image = pipeline.controlnet_guider.prepare_input(data.control_image, data.control_image) - - # (4) Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - - # (5) Denoise loop - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - # prepare latents for unet using the guider - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) - - # prepare latents for controlnet using the guider - data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) - - if isinstance(data.controlnet_keep[i], list): - data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] - else: - data.controlnet_cond_scale = data.controlnet_conditioning_scale - if isinstance(data.controlnet_cond_scale, list): - data.controlnet_cond_scale = data.controlnet_cond_scale[0] - data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(data.control_model_input, t), - t, - encoder_hidden_states=data.controlnet_prompt_embeds, - controlnet_cond=data.control_image, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, - return_dict=False, - ) - - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - data.down_block_res_samples = pipeline.guider.prepare_input( - data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] - ) - data.mid_block_res_sample = pipeline.guider.prepare_input( - data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) - ) - - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) - # compute the previous noisy sample x_t -> x_t-1 - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (e.g.
Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - data.latents = data.latents.to(data.latents_dtype) - - - if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: - data.init_latents_proper = data.image_latents - if i < len(data.timesteps) - 1: - data.noise_timestep = data.timesteps[i + 1] - data.init_latents_proper = pipeline.scheduler.add_noise( - data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep]) - ) - - data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents - - if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): - progress_bar.update() - - pipeline.guider.reset_guider(pipeline) - pipeline.controlnet_guider.reset_guider(pipeline) - - self.add_block_state(state, data) - - return pipeline, state - - class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): - expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"] - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "The denoising step for the ControlNet Union model; works for inpainting, image-to-image, and text-to-image tasks" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "control_image", - required=True, - type_hint=PipelineImageInput, - description="The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass images as a list for proper batching."), - InputParam( - "control_guidance_start", - default=0.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet starts applying."), - InputParam( - "control_guidance_end", - default=1.0, - type_hint=Union[float, List[float]], - description="The percentage of total steps at which the ControlNet stops applying."), - InputParam( - "control_mode", - required=True, - type_hint=List[int], - description="The control mode for the union controlnet: 0 for openpose, 1 for depth, 2 for hed/pidi/scribble/ted, 3 for canny/lineart/anime_lineart/mlsd, 4 for normal, and 5 for segment" - ), - InputParam( - "controlnet_conditioning_scale", - default=1.0, - type_hint=Union[float, List[float]], - description="Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list of scales." - ), - InputParam( - "guess_mode", - default=False, - type_hint=bool, - description="Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0." - ), - InputParam( - "num_images_per_prompt", - default=1, - type_hint=int, - description="The number of images to generate per prompt." - ), - InputParam( - "guidance_scale", - default=5.0, - type_hint=float, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality.
Enabled when > 1."), - InputParam( - "guidance_rescale", - default=0.0, - type_hint=float, - description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'."), - InputParam( - "cross_attention_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the AttentionProcessor."), - InputParam( - "generator", - default=None, - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="One or a list of torch generator(s) to make generation deterministic."), - InputParam( - "eta", - default=0.0, - type_hint=float, - description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others."), - InputParam( - "guider_kwargs", - default=None, - type_hint=Optional[Dict[str, Any]], - description="Optional kwargs dictionary passed to the Guider."), - ] - - @property - def intermediates_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." - ), - InputParam( - "batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." - ), - InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, - description="The time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step." - ), - InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], - description="The negative time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step. " - ), - InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, - description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." - ), - InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], - description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" - ), - InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], - description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." 
- ), - InputParam( - "mask", - type_hint=Optional[torch.Tensor], - description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], - description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "noise", - type_hint=Optional[torch.Tensor], - description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." - ), - InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], - description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], - description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode or prepare_latent step." - ), - InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], - description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." - ), - ] - - @property - def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - - def __init__(self): - super().__init__() - self.components["guider"] = CFGGuider() - self.components["controlnet_guider"] = CFGGuider() - self.components["scheduler"] = None - self.components["unet"] = None - self.components["controlnet"] = None - control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False) - self.auxiliaries["control_image_processor"] = control_image_processor - - def check_inputs(self, pipeline, data): - - num_channels_unet = pipeline.unet.config.in_channels - if num_channels_unet == 9: - # default case for runwayml/stable-diffusion-inpainting - if data.mask is None or data.masked_image_latents is None: - raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") - num_channels_latents = data.latents.shape[1] - num_channels_mask = data.mask.shape[1] - num_channels_masked_image = data.masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {pipeline.unet.config} expects" - f" {pipeline.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." 
- ) - - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - data = self.get_block_state(state) - self.check_inputs(pipeline, data) - - data.num_channels_unet = pipeline.unet.config.in_channels - - # (1) prepare controlnet inputs - data.device = pipeline._execution_device - data.height, data.width = data.latents.shape[-2:] - data.height = data.height * pipeline.vae_scale_factor - data.width = data.width * pipeline.vae_scale_factor - - controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet - - # (1.1) - # control guidance - if not isinstance(data.control_guidance_start, list) and isinstance(data.control_guidance_end, list): - data.control_guidance_start = len(data.control_guidance_end) * [data.control_guidance_start] - elif not isinstance(data.control_guidance_end, list) and isinstance(data.control_guidance_start, list): - data.control_guidance_end = len(data.control_guidance_start) * [data.control_guidance_end] - - # (1.2) - # global_pool_conditions & guess_mode - data.global_pool_conditions = controlnet.config.global_pool_conditions - data.guess_mode = data.guess_mode or data.global_pool_conditions - - # (1.3) - # control_type - data.num_control_type = controlnet.config.num_control_type - - # (1.4) - # control_type - if not isinstance(data.control_image, list): - data.control_image = [data.control_image] - - if not isinstance(data.control_mode, list): - data.control_mode = [data.control_mode] - - if len(data.control_image) != len(data.control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") - - data.control_type = [0 for _ in range(data.num_control_type)] - for control_idx in data.control_mode: - data.control_type[control_idx] = 1 - - data.control_type = torch.Tensor(data.control_type) - - # (1.5) - # prepare control_image - for idx, _ in enumerate(data.control_image): - data.control_image[idx] = pipeline.prepare_control_image( - image=data.control_image[idx], - width=data.width, - height=data.height, - batch_size=data.batch_size * data.num_images_per_prompt, - num_images_per_prompt=data.num_images_per_prompt, - device=data.device, - dtype=controlnet.dtype, - crops_coords=data.crops_coords, - ) - data.height, data.width = data.control_image[idx].shape[-2:] - - - # (1.6) - # controlnet_keep - data.controlnet_keep = [] - for i in range(len(data.timesteps)): - data.controlnet_keep.append( - 1.0 - - float(i / len(data.timesteps) < data.control_guidance_start or (i + 1) / len(data.timesteps) > data.control_guidance_end) - ) - - # (2) Prepare conditional inputs for unet using the guider - # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale - data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False - data.guider_kwargs = data.guider_kwargs or {} - data.guider_kwargs = { - **data.guider_kwargs, - "disable_guidance": data.disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.guider.set_guider(pipeline, data.guider_kwargs) - data.prompt_embeds = pipeline.guider.prepare_input( - data.prompt_embeds, - data.negative_prompt_embeds, - ) - data.add_time_ids = pipeline.guider.prepare_input( - data.add_time_ids, - data.negative_add_time_ids, - ) - data.pooled_prompt_embeds = pipeline.guider.prepare_input( - data.pooled_prompt_embeds, - data.negative_pooled_prompt_embeds, - ) - - if 
data.num_channels_unet == 9: - data.mask = pipeline.guider.prepare_input(data.mask, data.mask) - data.masked_image_latents = pipeline.guider.prepare_input(data.masked_image_latents, data.masked_image_latents) - - data.added_cond_kwargs = { - "text_embeds": data.pooled_prompt_embeds, - "time_ids": data.add_time_ids, - } - - if data.ip_adapter_embeds is not None: - data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds) - data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds - - # (3) Prepare conditional inputs for controlnet using the guider - data.controlnet_disable_guidance = True if data.disable_guidance or data.guess_mode else False - data.controlnet_guider_kwargs = data.guider_kwargs or {} - data.controlnet_guider_kwargs = { - **data.controlnet_guider_kwargs, - "disable_guidance": data.controlnet_disable_guidance, - "guidance_scale": data.guidance_scale, - "guidance_rescale": data.guidance_rescale, - "batch_size": data.batch_size * data.num_images_per_prompt, - } - pipeline.controlnet_guider.set_guider(pipeline, data.controlnet_guider_kwargs) - data.controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(data.prompt_embeds) - data.controlnet_added_cond_kwargs = { - "text_embeds": pipeline.controlnet_guider.prepare_input(data.pooled_prompt_embeds), - "time_ids": pipeline.controlnet_guider.prepare_input(data.add_time_ids), - } - for idx, _ in enumerate(data.control_image): - data.control_image[idx] = pipeline.controlnet_guider.prepare_input(data.control_image[idx], data.control_image[idx]) - - data.control_type = ( - data.control_type.reshape(1, -1) - .to(data.device, dtype=data.prompt_embeds.dtype) - ) - repeat_by = data.batch_size * data.num_images_per_prompt // data.control_type.shape[0] - data.control_type = data.control_type.repeat_interleave(repeat_by, dim=0) - data.control_type = pipeline.controlnet_guider.prepare_input(data.control_type, data.control_type) - - # (4) Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline - data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta) - data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0) - - - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: - for i, t in enumerate(data.timesteps): - # prepare latents for unet using the guider - data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents) - - # prepare latents for controlnet using the guider - data.control_model_input = pipeline.controlnet_guider.prepare_input(data.latents, data.latents) - - if isinstance(data.controlnet_keep[i], list): - data.cond_scale = [c * s for c, s in zip(data.controlnet_conditioning_scale, data.controlnet_keep[i])] - else: - data.controlnet_cond_scale = data.controlnet_conditioning_scale - if isinstance(data.controlnet_cond_scale, list): - data.controlnet_cond_scale = data.controlnet_cond_scale[0] - data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - - data.down_block_res_samples, data.mid_block_res_sample = pipeline.controlnet( - pipeline.scheduler.scale_model_input(data.control_model_input, t), - t, - encoder_hidden_states=data.controlnet_prompt_embeds, - controlnet_cond=data.control_image, - control_type=data.control_type, - control_type_idx=data.control_mode, - conditioning_scale=data.cond_scale, - guess_mode=data.guess_mode, - added_cond_kwargs=data.controlnet_added_cond_kwargs, - return_dict=False, - ) - - # when we apply guidance for unet, but not for controlnet: - # add 0 to the unconditional batch - data.down_block_res_samples = pipeline.guider.prepare_input( - data.down_block_res_samples, [torch.zeros_like(d) for d in data.down_block_res_samples] - ) - data.mid_block_res_sample = pipeline.guider.prepare_input( - data.mid_block_res_sample, torch.zeros_like(data.mid_block_res_sample) - ) - - data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t) - if data.num_channels_unet == 9: - data.latent_model_input = torch.cat([data.latent_model_input, data.mask, data.masked_image_latents], dim=1) - - data.noise_pred = pipeline.unet( - data.latent_model_input, - t, - encoder_hidden_states=data.prompt_embeds, - timestep_cond=data.timestep_cond, - cross_attention_kwargs=data.cross_attention_kwargs, - added_cond_kwargs=data.added_cond_kwargs, - down_block_additional_residuals=data.down_block_res_samples, - mid_block_additional_residual=data.mid_block_res_sample, - return_dict=False, - )[0] - # perform guidance - data.noise_pred = pipeline.guider.apply_guidance(data.noise_pred, timestep=t, latents=data.latents) - # compute the previous noisy sample x_t -> x_t-1 - data.latents_dtype = data.latents.dtype - data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0] - if data.latents.dtype != data.latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        data.latents = data.latents.to(data.latents_dtype)
-
-                # latent blending applies only to the standard 4-channel unet; a 9-channel inpainting unet already receives the mask via concatenation above
-                if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None:
-                    data.init_latents_proper = data.image_latents
-                    if i < len(data.timesteps) - 1:
-                        data.noise_timestep = data.timesteps[i + 1]
-                        data.init_latents_proper = pipeline.scheduler.add_noise(
-                            data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep])
-                        )
-
-                    data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents
-
-                if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
-                    progress_bar.update()
-
-        pipeline.guider.reset_guider(pipeline)
-        pipeline.controlnet_guider.reset_guider(pipeline)
-
-        self.add_block_state(state, data)
-
-        return pipeline, state
-
-
-class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
-    expected_components = ["vae"]
-    model_name = "stable-diffusion-xl"
-
-    @property
-    def description(self) -> str:
-        return "Step that decodes the denoised latents into images"
-
-    @property
-    def inputs(self) -> List[Tuple[str, Any]]:
-        return [
-            InputParam(
-                "output_type",
-                type_hint=str,
-                default="pil",
-                description="The output format of the generated image. Choose between PIL (PIL.Image.Image), torch.Tensor or np.array."
-            ),
-        ]
-
-    @property
-    def intermediates_inputs(self) -> List[str]:
-        return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")]
-
-    @property
-    def intermediates_outputs(self) -> List[str]:
-        return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")]
-
-    def __init__(self):
-        super().__init__()
-        self.components["vae"] = None
-        self.auxiliaries["image_processor"] = VaeImageProcessor(vae_scale_factor=8)
-
-    @torch.no_grad()
-    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
-        data = self.get_block_state(state)
-
-        if not data.output_type == "latent":
-            # make sure the VAE is in float32 mode, as it overflows in float16
-            data.needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
-
-            if data.needs_upcasting:
-                pipeline.upcast_vae()
-                data.latents = data.latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype)
-            elif data.latents.dtype != pipeline.vae.dtype:
-                if torch.backends.mps.is_available():
-                    # some platforms (eg.
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - pipeline.vae = pipeline.vae.to(data.latents.dtype) - - # unscale/denormalize the latents - # denormalize with the mean and std if available and not None - data.has_latents_mean = ( - hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None - ) - data.has_latents_std = ( - hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None - ) - if data.has_latents_mean and data.has_latents_std: - data.latents_mean = ( - torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) - ) - data.latents_std = ( - torch.tensor(pipeline.vae.config.latents_std).view(1, 4, 1, 1).to(data.latents.device, data.latents.dtype) - ) - data.latents = data.latents * data.latents_std / pipeline.vae.config.scaling_factor + data.latents_mean - else: - data.latents = data.latents / pipeline.vae.config.scaling_factor - - data.images = pipeline.vae.decode(data.latents, return_dict=False)[0] - - # cast back to fp16 if needed - if data.needs_upcasting: - pipeline.vae.to(dtype=torch.float16) - else: - data.images = data.latents - - # apply watermark if available - if hasattr(pipeline, "watermark") and pipeline.watermark is not None: - data.images = pipeline.watermark.apply_watermark(data.images) - - data.images = pipeline.image_processor.postprocess(data.images, output_type=data.output_type) - - self.add_block_state(state, data) - - return pipeline, state - - -class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): - model_name = "stable-diffusion-xl" - - @property - def description(self) -> str: - return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \ - "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" - - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam( - "image", - type_hint=PipelineImageInput, - required=True, - description="The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of the image will be masked out with `mask_image` and repainted according to `prompt`." - ), - InputParam( - "mask_image", - type_hint=PipelineImageInput, - required=True, - description="The mask image(s) to use for inpainting, white pixels in the mask will be repainted, while black pixels will be preserved. If mask_image is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be (B, H, W, 1). Must be a `PIL.Image.Image`" - ), - InputParam( - "padding_mask_crop", - type_hint=Optional[Tuple[int, int]], - default=None, - description="The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied. If set, it will find a rectangular region with the same aspect ratio as the image that contains all masked areas, then expand that area by this margin. The image and mask_image are cropped to this expanded area before resizing to the original size for inpainting. Useful when the masked area is small in a large image with irrelevant background information." 
-            ),
-        ]
-
-    @property
-    def intermediates_inputs(self) -> List[str]:
-        return [
-            InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], description="The generated images from the decode step"),
-            InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.")
-        ]
-
-    @property
-    def intermediates_outputs(self) -> List[str]:
-        return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], description="The generated images with the mask overlaid")]
-
-    @torch.no_grad()
-    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
-        data = self.get_block_state(state)
-
-        if data.padding_mask_crop is not None and data.crops_coords is not None:
-            data.images = [pipeline.image_processor.apply_overlay(data.mask_image, data.image, i, data.crops_coords) for i in data.images]
-
-        self.add_block_state(state, data)
-
-        return pipeline, state
-
-
-class StableDiffusionXLOutputStep(PipelineBlock):
-    model_name = "stable-diffusion-xl"
-
-    @property
-    def description(self) -> str:
-        return "Final step that returns a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple."
-
-    @property
-    def inputs(self) -> List[Tuple[str, Any]]:
-        return [InputParam("return_dict", type_hint=bool, default=True, description="Whether or not to return a StableDiffusionXLPipelineOutput instead of a plain tuple.")]
-
-    @property
-    def intermediates_inputs(self) -> List[str]:
-        return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], description="The generated images from the decode step.")]
-
-    @property
-    def intermediates_outputs(self) -> List[str]:
-        return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")]
-
-    @torch.no_grad()
-    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
-        data = self.get_block_state(state)
-
-        if not data.return_dict:
-            data.images = (data.images,)
-        else:
-            data.images = StableDiffusionXLPipelineOutput(images=data.images)
-        self.add_block_state(state, data)
-        return pipeline, state
-
-
-# Encode
-class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
-    block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
-    block_names = ["inpaint", "img2img"]
-    block_trigger_inputs = ["mask_image", "image"]
-
-    @property
-    def description(self):
-        return "Vae encoder step that encodes the image inputs into their latent representations.\n" + \
-            "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \
-            " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \
-            " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided."
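The `block_trigger_inputs` list above is what drives automatic task selection: an `AutoPipelineBlocks` container runs the first sub-block whose trigger input was actually supplied, and a `None` trigger marks the fallback branch. A minimal sketch of that dispatch rule (a hypothetical `select_block` helper, not the library's implementation):

def select_block(block_names, block_trigger_inputs, provided_inputs):
    # Hypothetical illustration of AutoPipelineBlocks-style trigger dispatch.
    fallback = None
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None:
            fallback = name  # fallback branch, e.g. text2img
        elif provided_inputs.get(trigger) is not None:
            return name  # first matching trigger wins, so list order encodes priority
    return fallback

# e.g. select_block(["inpaint", "img2img"], ["mask_image", "image"], {"image": img})
# returns "img2img" because `mask_image` was not supplied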
-
-
-# Before denoise
-class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLInputStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLPrepareAdditionalConditioningStep]
-    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
-
-    @property
-    def description(self):
-        return "Before denoise step that prepares the inputs for the denoise step.\n" + \
-            "This is a sequential pipeline block:\n" + \
-            " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \
-            " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + \
-            " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + \
-            " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning"
-
-class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLImg2ImgPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep]
-    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
-
-    @property
-    def description(self):
-        return "Before denoise step that prepares the inputs for the denoise step for the img2img task.\n" + \
-            "This is a sequential pipeline block:\n" + \
-            " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \
-            " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \
-            " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + \
-            " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning"
-
-class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLInputStep, StableDiffusionXLImg2ImgSetTimestepsStep, StableDiffusionXLInpaintPrepareLatentsStep, StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep]
-    block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
-
-    @property
-    def description(self):
-        return "Before denoise step that prepares the inputs for the denoise step for the inpainting task.\n" + \
-            "This is a sequential pipeline block:\n" + \
-            " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n" + \
-            " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + \
-            " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + \
-            " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning"
-
-
-class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
-    block_classes = [StableDiffusionXLInpaintBeforeDenoiseStep, StableDiffusionXLImg2ImgBeforeDenoiseStep, StableDiffusionXLBeforeDenoiseStep]
-    block_names = ["inpaint", "img2img", "text2img"]
-    block_trigger_inputs = ["mask", "image_latents", None]
-
-    @property
-    def description(self):
-        return "Before denoise step that prepares the inputs for the denoise step.\n" + \
-            "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n" + \
-            " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + \
-            " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + \
-            " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when neither `image_latents` nor `mask` is provided."
-
-
-# Denoise
-class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
-    block_classes = [StableDiffusionXLControlNetUnionDenoiseStep, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDenoiseStep]
-    block_names = ["controlnet_union", "controlnet", "unet"]
-    block_trigger_inputs = ["control_mode", "control_image", None]
-
-    @property
-    def description(self):
-        return "Denoise step that denoises the latents.\n" + \
-            "This is an auto pipeline block that works with controlnet, controlnet_union, and no controlnet.\n" + \
-            " - `StableDiffusionXLControlNetUnionDenoiseStep` (controlnet_union) is used when both `control_mode` and `control_image` are provided.\n" + \
-            " - `StableDiffusionXLControlNetDenoiseStep` (controlnet) is used when `control_image` is provided.\n" + \
-            " - `StableDiffusionXLDenoiseStep` (unet only) is used when neither `control_mode` nor `control_image` is provided."
-
-# After denoise
-
-class StableDiffusionXLDecodeStep(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLOutputStep]
-    block_names = ["decode", "output"]
-
-    @property
-    def description(self):
-        return """Decode step that decodes the denoised latents into image outputs.
-This is a sequential pipeline block:
- - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images
- - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple."""
-
-
-class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLDecodeLatentsStep, StableDiffusionXLInpaintOverlayMaskStep, StableDiffusionXLOutputStep]
-    block_names = ["decode", "mask_overlay", "output"]
-
-    @property
-    def description(self):
-        return "Inpaint decode step that decodes the denoised latents into image outputs.\n" + \
-            "This is a sequential pipeline block:\n" + \
-            " - `StableDiffusionXLDecodeLatentsStep` is used to decode the denoised latents into images\n" + \
-            " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image\n" + \
-            " - `StableDiffusionXLOutputStep` is used to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or a plain tuple."
-
-
-class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
-    block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
-    block_names = ["inpaint", "non-inpaint"]
-    block_trigger_inputs = ["padding_mask_crop", None]
-
-    @property
-    def description(self):
-        return "Decode step that decodes the denoised latents into image outputs.\n" + \
-            "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \
-            " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \
-            " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
-
-class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
-    block_classes = [StableDiffusionXLIPAdapterStep]
-    block_names = ["ip_adapter"]
-    block_trigger_inputs = ["ip_adapter_image"]
-
-    @property
-    def description(self):
-        return "Run the IP Adapter step if `ip_adapter_image` is provided."
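All of these blocks share the `(pipeline, state)` call contract visible in the `__call__` methods above, which is what makes them freely composable. A minimal sketch of the chaining a `SequentialPipelineBlocks` container performs, assuming only that contract (class name hypothetical):

class MiniSequentialBlocks:
    # Hypothetical stand-in for the SequentialPipelineBlocks composition loop.
    def __init__(self, blocks):
        self.blocks = blocks

    def __call__(self, pipeline, state):
        # Thread the shared state through each block in order: every block reads
        # its inputs from `state` and writes its outputs back for later blocks.
        for block in self.blocks:
            pipeline, state = block(pipeline, state)
        return pipeline, state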
- -class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks): - block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep] - block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"] - - @property - def description(self): - return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \ - "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \ - "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \ - "- to run the controlnet workflow, you need to provide `control_image`\n" + \ - "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \ - "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \ - "- for text-to-image generation, all you need to provide is `prompt`" - -# block mapping -TEXT2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLSetTimestepsStep), - ("prepare_latents", StableDiffusionXLPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -IMAGE2IMAGE_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) - -INPAINT_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLInpaintDecodeStep) -]) - -CONTROLNET_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetDenoiseStep), -]) - -CONTROLNET_UNION_BLOCKS = OrderedDict([ - ("denoise", StableDiffusionXLControlNetUnionDenoiseStep), -]) - -IP_ADAPTER_BLOCKS = OrderedDict([ - ("ip_adapter", StableDiffusionXLIPAdapterStep), -]) - -AUTO_BLOCKS = OrderedDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), - ("image_encoder", StableDiffusionXLAutoVaeEncoderStep), - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), - ("decode", StableDiffusionXLAutoDecodeStep) -]) - -AUTO_CORE_BLOCKS = OrderedDict([ - ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep), - ("denoise", StableDiffusionXLAutoDenoiseStep), -]) - - -SDXL_SUPPORTED_BLOCKS = { 
- "text2img": TEXT2IMAGE_BLOCKS, - "img2img": IMAGE2IMAGE_BLOCKS, - "inpaint": INPAINT_BLOCKS, - "controlnet": CONTROLNET_BLOCKS, - "controlnet_union": CONTROLNET_UNION_BLOCKS, - "ip_adapter": IP_ADAPTER_BLOCKS, - "auto": AUTO_BLOCKS -} - - -class StableDiffusionXLModularPipeline( - ModularPipeline, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionXLLoraLoaderMixin, - ModularIPAdapterMixin, -): - @property - def default_sample_size(self): - default_sample_size = 128 - if hasattr(self, "unet") and self.unet is not None: - default_sample_size = self.unet.config.sample_size - return default_sample_size - - @property - def vae_scale_factor(self): - vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - return vae_scale_factor - - @property - def num_channels_unet(self): - num_channels_unet = 4 - if hasattr(self, "unet") and self.unet is not None: - num_channels_unet = self.unet.config.in_channels - return num_channels_unet - - @property - def num_channels_latents(self): - num_channels_latents = 4 - if hasattr(self, "vae") and self.vae is not None: - num_channels_latents = self.vae.config.latent_channels - return num_channels_latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None - ): - add_time_ids = list(original_size + crops_coords_top_left + target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
- ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids - def _get_add_time_ids_img2img( - self, - original_size, - crops_coords_top_left, - target_size, - aesthetic_score, - negative_aesthetic_score, - negative_original_size, - negative_crops_coords_top_left, - negative_target_size, - dtype, - text_encoder_projection_dim=None, - ): - if self.config.requires_aesthetics_score: - add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) - add_neg_time_ids = list( - negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) - ) - else: - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) - - passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim - ) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if ( - expected_add_embed_dim > passed_add_embed_dim - and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." - ) - elif ( - expected_add_embed_dim < passed_add_embed_dim - and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim - ): - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." - ) - elif expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
- ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) - - return add_time_ids, add_neg_time_ids - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image - # 1. return image without apply any guidance - # 2. add crops_coords and resize_mode to preprocess() - def prepare_control_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - crops_coords=None, - ): - if crops_coords is not None: - image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) - else: - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - return image - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt - def encode_prompt( - self, - prompt: str, - prompt_2: Optional[str] = None, - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - pooled_prompt_embeds: Optional[torch.Tensor] = None, - negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is - used in both text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. 
- """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # textual inversion: process multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - # "2" because SDXL always indexes from the penultimate layer. 
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - if self.text_encoder_2 is not None: - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - if self.text_encoder_2 is not None: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - else: - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = 
negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - if self.text_encoder is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance - ): - image_embeds = [] - if do_classifier_free_guidance: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): - # get the original timestep using init_timestep - if denoising_start is None: - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - t_start = max(num_inference_steps - init_timestep, 0) - - 
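-            # e.g. strength=0.7 with num_inference_steps=50 gives init_timestep=35 and t_start=15,
-            # so (with a first-order scheduler) only the final 35 steps are run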
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) - - return timesteps, num_inference_steps - t_start - - else: - # Strength is irrelevant if we directly request a timestep to start at; - # that is, strength is determined by the denoising_start instead. - discrete_timestep_cutoff = int( - round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) - ) - ) - - num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() - if self.scheduler.order == 2 and num_inference_steps % 2 == 0: - # if the scheduler is a 2nd order scheduler we might have to do +1 - # because `num_inference_steps` might be even given that every timestep - # (except the highest one) is duplicated. If `num_inference_steps` is even it would - # mean that we cut the timesteps in the middle of the denoising step - # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 - # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler - num_inference_steps = num_inference_steps + 1 - - # because t_n+1 >= t_n, we slice the timesteps starting from the end - t_start = len(self.scheduler.timesteps) - num_inference_steps - timesteps = self.scheduler.timesteps[t_start:] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start) - return timesteps, num_inference_steps - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents - # YiYi TODO: refactor using _encode_vae_image - def prepare_latents_img2img( - self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - # Offload text encoder if `enable_model_cpu_offload` was enabled - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - - if image.shape[1] == 4: - init_latents = image - - else: - latents_mean = latents_std = None - if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: - latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) - if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: - latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) - # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.config.force_upcast: - image = image.float() - self.vae.to(dtype=torch.float32) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
-    # YiYi TODO: refactor using _encode_vae_image
-    def prepare_latents_img2img(
-        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
-    ):
-        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
-            raise ValueError(
-                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
-            )
-
-        # Offload text encoder if `enable_model_cpu_offload` was enabled
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.text_encoder_2.to("cpu")
-            torch.cuda.empty_cache()
-
-        image = image.to(device=device, dtype=dtype)
-
-        batch_size = batch_size * num_images_per_prompt
-
-        if image.shape[1] == 4:
-            init_latents = image
-
-        else:
-            latents_mean = latents_std = None
-            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-            # make sure the VAE is in float32 mode, as it overflows in float16
-            if self.vae.config.force_upcast:
-                image = image.float()
-                self.vae.to(dtype=torch.float32)
-
-            if isinstance(generator, list) and len(generator) != batch_size:
-                raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                )
-
-            elif isinstance(generator, list):
-                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
-                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
-                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
-                    raise ValueError(
-                        f"Cannot duplicate `image` of batch size {image.shape[0]} to an effective batch size of {batch_size}."
-                    )
-
-                init_latents = [
-                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
-                    for i in range(batch_size)
-                ]
-                init_latents = torch.cat(init_latents, dim=0)
-            else:
-                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
-
-            if self.vae.config.force_upcast:
-                self.vae.to(dtype)
-
-            init_latents = init_latents.to(dtype)
-            if latents_mean is not None and latents_std is not None:
-                latents_mean = latents_mean.to(device=device, dtype=dtype)
-                latents_std = latents_std.to(device=device, dtype=dtype)
-                init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
-            else:
-                init_latents = self.vae.config.scaling_factor * init_latents
-
-        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
-            # expand init_latents for batch_size
-            additional_image_per_prompt = batch_size // init_latents.shape[0]
-            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
-        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
-            raise ValueError(
-                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-            )
-        else:
-            init_latents = torch.cat([init_latents], dim=0)
-
-        if add_noise:
-            shape = init_latents.shape
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            # get latents
-            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
-
-        latents = init_latents
-
-        return latents
-
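The `add_noise` call at the end of `prepare_latents_img2img` is what ties `strength` to the starting point of denoising. A self-contained sketch, with the step count, strength, and latent shapes all assumed for illustration:

```python
# Sketch: img2img seeds denoising by noising the encoded image at the
# strength-selected timestep instead of starting from pure noise.
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

strength = 0.6
init_timestep = min(int(50 * strength), 50)   # 30 denoising steps will actually run
t_start = max(50 - init_timestep, 0)          # the 20 noisiest steps are skipped
latent_timestep = scheduler.timesteps[t_start : t_start + 1]

image_latents = torch.randn(1, 4, 64, 64)     # stand-in for VAE-encoded image latents
noise = torch.randn_like(image_latents)
noisy_latents = scheduler.add_noise(image_latents, noise, latent_timestep)
```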
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents
-    def prepare_latents_inpaint(
-        self,
-        batch_size,
-        num_channels_latents,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        latents=None,
-        image=None,
-        timestep=None,
-        is_strength_max=True,
-        add_noise=True,
-        return_noise=False,
-        return_image_latents=False,
-    ):
-        shape = (
-            batch_size,
-            num_channels_latents,
-            int(height) // self.vae_scale_factor,
-            int(width) // self.vae_scale_factor,
-        )
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if (image is None or timestep is None) and not is_strength_max:
-            raise ValueError(
-                "Since strength < 1, initial latents are to be initialized as a combination of Image + Noise. "
-                "However, either the image or the noise timestep has not been provided."
-            )
-
-        if image.shape[1] == 4:
-            image_latents = image.to(device=device, dtype=dtype)
-            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
-        elif return_image_latents or (latents is None and not is_strength_max):
-            image = image.to(device=device, dtype=dtype)
-            image_latents = self._encode_vae_image(image=image, generator=generator)
-            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
-
-        if latents is None and add_noise:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            # if strength is 1. then initialize the latents to noise, else to a combination of image + noise
-            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
-            # if pure noise then scale the initial latents by the scheduler's init sigma
-            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
-        elif add_noise:
-            noise = latents.to(device)
-            latents = noise * self.scheduler.init_noise_sigma
-        else:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = image_latents.to(device)
-
-        outputs = (latents,)
-
-        if return_noise:
-            outputs += (noise,)
-
-        if return_image_latents:
-            outputs += (image_latents,)
-
-        return outputs
-
-    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image
-    # YiYi TODO: update the _encode_vae_image so that we can use # Copied from
-    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
-        latents_mean = latents_std = None
-        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-
-        dtype = image.dtype
-        if self.vae.config.force_upcast:
-            image = image.float()
-            self.vae.to(dtype=torch.float32)
-
-        if isinstance(generator, list):
-            image_latents = [
-                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
-                for i in range(image.shape[0])
-            ]
-            image_latents = torch.cat(image_latents, dim=0)
-        else:
-            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
-
-        if self.vae.config.force_upcast:
-            self.vae.to(dtype)
-
-        image_latents = image_latents.to(dtype)
-        if latents_mean is not None and latents_std is not None:
-            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
-            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
-            image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
-        else:
-            image_latents = self.vae.config.scaling_factor * image_latents
-
-        return image_latents
-
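The mean/std branch in `_encode_vae_image` mirrors the normalization that some VAE checkpoints declare in their config. A sketch with assumed statistics (identity mean/std and a typical SDXL scaling factor), showing that it reduces to plain scaling when the config provides nothing:

```python
# Sketch (assumed config values): normalizing VAE latents with latents_mean/std
# versus plain scaling when those fields are absent.
import torch

scaling_factor = 0.13025                 # typical SDXL vae.config.scaling_factor
latents = torch.randn(1, 4, 64, 64)
latents_mean = torch.zeros(1, 4, 1, 1)   # stand-in for vae.config.latents_mean
latents_std = torch.ones(1, 4, 1, 1)     # stand-in for vae.config.latents_std

normalized = (latents - latents_mean) * scaling_factor / latents_std
scaled = scaling_factor * latents        # the fallback path
assert torch.allclose(normalized, scaled)
```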
-    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
-    # unlike the original, this does not accept `do_classifier_free_guidance`
-    def prepare_mask_latents(
-        self, mask, masked_image, batch_size, height, width, dtype, device, generator
-    ):
-        # resize the mask to latents shape as we concatenate the mask to the latents
-        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
-        # and half precision
-        mask = torch.nn.functional.interpolate(
-            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
-        )
-        mask = mask.to(device=device, dtype=dtype)
-
-        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        if mask.shape[0] < batch_size:
-            if not batch_size % mask.shape[0] == 0:
-                raise ValueError(
-                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
-                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
-                    " of masks that you pass is divisible by the total requested batch size."
-                )
-            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
-
-        if masked_image is not None and masked_image.shape[1] == 4:
-            masked_image_latents = masked_image
-        else:
-            masked_image_latents = None
-
-        if masked_image is not None:
-            if masked_image_latents is None:
-                masked_image = masked_image.to(device=device, dtype=dtype)
-                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
-
-            if masked_image_latents.shape[0] < batch_size:
-                if not batch_size % masked_image_latents.shape[0] == 0:
-                    raise ValueError(
-                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
-                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
-                        " Make sure the number of images that you pass is divisible by the total requested batch size."
-                    )
-                masked_image_latents = masked_image_latents.repeat(
-                    batch_size // masked_image_latents.shape[0], 1, 1, 1
-                )
-
-            # aligning device to prevent device errors when concatenating it with the latent model input
-            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
-
-        return mask, masked_image_latents
-
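A quick sketch of the mask downsampling performed at the top of `prepare_mask_latents`, assuming the SDXL default `vae_scale_factor` of 8 and an arbitrary mask size:

```python
# Sketch: the pixel-space mask is resized to the latent grid so it can be
# concatenated channel-wise with the latents.
import torch
import torch.nn.functional as F

vae_scale_factor = 8                          # assumed SDXL default
mask = torch.rand(1, 1, 1024, 1024)           # pixel-space mask in [0, 1]
latent_mask = F.interpolate(
    mask, size=(1024 // vae_scale_factor, 1024 // vae_scale_factor)
)
print(latent_mask.shape)                      # torch.Size([1, 1, 128, 128])
```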
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
-    def prepare_extra_step_kwargs(self, generator, eta):
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        if accepts_generator:
-            extra_step_kwargs["generator"] = generator
-        return extra_step_kwargs
-
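The signature probing in `prepare_extra_step_kwargs` generalizes to any scheduler. A standalone sketch of the same pattern, with `DDIMScheduler` assumed as the example:

```python
# Sketch: pass `eta`/`generator` to scheduler.step() only when its signature accepts them.
import inspect
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
step_params = set(inspect.signature(scheduler.step).parameters.keys())

extra_step_kwargs = {}
if "eta" in step_params:           # DDIM-specific stochasticity knob
    extra_step_kwargs["eta"] = 0.0
if "generator" in step_params:     # some schedulers draw noise internally
    extra_step_kwargs["generator"] = None
print(extra_step_kwargs)
```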
- """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index b94b9ad4a7e3..0d28cb81af38 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1388,7 +1388,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ModularPipeline(metaclass=DummyObject): +class ModularLoader(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3e7c3a735ee9..a512b107cf96 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2432,7 +2432,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionXLModularPipeline(metaclass=DummyObject): +class StableDiffusionXLModularLoader(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 5d0752af8983..5d5eb23969ab 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -15,13 +15,16 @@ """Utilities to dynamically load objects from the Hub.""" import importlib +import signal import inspect import json import os import re import shutil import sys +import threading from pathlib import Path +from types import ModuleType from typing import Dict, Optional, Union from urllib import request @@ -37,6 +40,8 @@ # See https://huggingface.co/datasets/diffusers/community-pipelines-mirror COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror" +TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15)) +_HF_REMOTE_CODE_LOCK = threading.Lock() def get_diffusers_versions(): @@ -154,15 +159,87 @@ def check_imports(filename): return get_relative_imports(filename) -def get_class_in_module(class_name, module_path): +def _raise_timeout_error(signum, frame): + raise ValueError( + "Loading this model requires you to execute custom code contained in the model repository on your local " + "machine. Please set the option `trust_remote_code=True` to permit loading of this model." + ) + + +def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code): + if trust_remote_code is None: + if has_remote_code and TIME_OUT_REMOTE_CODE > 0: + prev_sig_handler = None + try: + prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) + signal.alarm(TIME_OUT_REMOTE_CODE) + while trust_remote_code is None: + answer = input( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n" + f"Do you wish to run the custom code? 
[y/N] " + ) + if answer.lower() in ["yes", "y", "1"]: + trust_remote_code = True + elif answer.lower() in ["no", "n", "0", ""]: + trust_remote_code = False + signal.alarm(0) + except Exception: + # OS which does not support signal.SIGALRM + raise ValueError( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + finally: + if prev_sig_handler is not None: + signal.signal(signal.SIGALRM, prev_sig_handler) + signal.alarm(0) + elif has_remote_code: + # For the CI which puts the timeout at 0 + _raise_timeout_error(None, None) + + if has_remote_code and not trust_remote_code: + raise ValueError( + f"Loading {model_name} requires you to execute the configuration file in that" + " repo on your local machine. Make sure you have read the code there to avoid malicious use, then" + " set the option `trust_remote_code=True` to remove this error." + ) + + return trust_remote_code + + +def get_class_in_module(class_name, module_path, force_reload=False): """ Import a module on the cache directory for modules and extract a class from it. """ - module_path = module_path.replace(os.path.sep, ".") - module = importlib.import_module(module_path) + name = os.path.normpath(module_path) + if name.endswith(".py"): + name = name[:-3] + name = name.replace(os.path.sep, ".") + module_file: Path = Path(HF_MODULES_CACHE) / module_path + + with _HF_REMOTE_CODE_LOCK: + if force_reload: + sys.modules.pop(name, None) + importlib.invalidate_caches() + cached_module: Optional[ModuleType] = sys.modules.get(name) + module_spec = importlib.util.spec_from_file_location(name, location=module_file) + + module: ModuleType + if cached_module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + else: + module = cached_module + + module_spec.loader.exec_module(module) if class_name is None: return find_pipeline_class(module) + return getattr(module, class_name) @@ -454,4 +531,4 @@ def get_class_from_dynamic_module( revision=revision, local_files_only=local_files_only, ) - return get_class_in_module(class_name, final_module.replace(".py", "")) + return get_class_in_module(class_name, final_module) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index a5df07e4a3c2..622c0d124f97 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) +def unwrap_module(module): + """Unwraps a module if it was compiled with torch.compile()""" + return module._orig_mod if is_compiled_module(module) else module + + def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).