diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a4f55acf8b70..c9ee38ac6fda 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -249,7 +249,7 @@ "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", - "ModularPipeline", + "ModularLoader", "PNDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", @@ -502,7 +502,7 @@ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLModularPipeline", + "StableDiffusionXLModularLoader", "StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", @@ -840,7 +840,7 @@ KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, - ModularPipeline, + ModularLoader, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, @@ -1071,7 +1071,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularPipeline, + StableDiffusionXLModularLoader, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index aee275db0336..7b6bd2071ef4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -46,7 +46,7 @@ "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] - _import_structure["modular_pipeline"] = ["ModularPipeline"] + _import_structure["modular_pipeline"] = ["ModularLoader"] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] @@ -329,7 +329,7 @@ "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", - "StableDiffusionXLModularPipeline", + "StableDiffusionXLModularLoader", "StableDiffusionXLAutoPipeline", ] ) @@ -468,7 +468,7 @@ 
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline - from .modular_pipeline import ModularPipeline + from .modular_pipeline import ModularLoader from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, @@ -693,7 +693,7 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLModularPipeline, + StableDiffusionXLModularLoader, StableDiffusionXLPipeline, StableDiffusionXLAutoPipeline, ) diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index 8c14321ccfac..bdff133e22d9 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -26,6 +26,7 @@ logging, ) from ..models.modeling_utils import ModelMixin +from .modular_pipeline_utils import ComponentSpec if is_accelerate_available(): @@ -229,54 +230,175 @@ def search_best_candidate(module_sizes, min_memory_offload): return hooks_to_offload + +from .modular_pipeline_utils import ComponentSpec +import uuid class ComponentsManager: def __init__(self): self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added + self.added_time = OrderedDict() # Store when components were added + self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - def add(self, name, component): - if name in self.components: - logger.warning(f"Overriding existing component '{name}' in ComponentsManager") - self.components[name] = component - self.added_time[name] = time.time() + + def _get_by_collection(self, collection: str): + """ + Select components by collection name. 
+ """ + selected_components = {} + if collection in self.collections: + component_ids = self.collections[collection] + for component_id in component_ids: + selected_components[component_id] = self.components[component_id] + return selected_components + + def _get_by_load_id(self, load_id: str): + """ + Select components by its load_id. + """ + selected_components = {} + for name, component in self.components.items(): + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: + selected_components[name] = component + return selected_components + + + def add(self, name, component, collection: Optional[str] = None): + + for comp_id, comp in self.components.items(): + if comp == component: + logger.warning(f"Component '{name}' already exists in ComponentsManager") + return comp_id + + component_id = f"{name}_{uuid.uuid4()}" + + if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null": + components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id) + if components_with_same_load_id: + existing = ", ".join(components_with_same_load_id.keys()) + logger.warning( + f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " + f"To remove a duplicate, call `components_manager.remove('')`." 
+ ) + + + # add component to components manager + self.components[component_id] = component + self.added_time[component_id] = time.time() + if collection: + if collection not in self.collections: + self.collections[collection] = set() + self.collections[collection].add(component_id) + if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) + self.enable_auto_cpu_offload(self._auto_offload_device) + + logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") + return component_id + + + def remove(self, name: Union[str, List[str]]): - def remove(self, name): if name not in self.components: logger.warning(f"Component '{name}' not found in ComponentsManager") return self.components.pop(name) self.added_time.pop(name) + + for collection in self.collections: + if name in self.collections[collection]: + self.collections[collection].remove(name) if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) - # YiYi TODO: looking into improving the search pattern - def get(self, names: Union[str, List[str]]): + def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None, + as_name_component_tuples: bool = False): """ - Get components by name with simple pattern matching. + Select components by name with simple pattern matching. 
Args: names: Component name(s) or pattern(s) Patterns: - - "unet" : exact match - - "!unet" : everything except exact match "unet" - - "base_*" : everything starting with "base_" - - "!base_*" : everything NOT starting with "base_" - - "*unet*" : anything containing "unet" - - "!*unet*" : anything NOT containing "unet" - - "refiner|vae|unet" : anything containing any of these terms - - "!refiner|vae|unet" : anything NOT containing any of these terms + - "unet" : match any component with base name "unet" (e.g., unet_123abc) + - "!unet" : everything except components with base name "unet" + - "unet*" : anything with base name starting with "unet" + - "!unet*" : anything with base name NOT starting with "unet" + - "*unet*" : anything with base name containing "unet" + - "!*unet*" : anything with base name NOT containing "unet" + - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet" + - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet" + - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae" + collection: Optional collection to filter by + load_id: Optional load_id to filter by + as_name_component_tuples: If True, returns a list of (name, component) tuples using base names + instead of a dictionary with component IDs as keys Returns: - Single component if names is str and matches one component, - dict of components if names matches multiple components or is a list + Dictionary mapping component IDs to components, + or list of (base_name, component) tuples if as_name_component_tuples=True """ + + if collection: + if collection not in self.collections: + logger.warning(f"Collection '{collection}' not found in ComponentsManager") + return [] if as_name_component_tuples else {} + components = self._get_by_collection(collection) + else: + components = self.components + + if load_id: + components = self._get_by_load_id(load_id) + + # Helper to extract base name 
from component_id + def get_base_name(component_id): + parts = component_id.split('_') + # If the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return component_id + + if names is None: + if as_name_component_tuples: + return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] + else: + return components + + # Create mapping from component_id to base_name for all components + base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} + + def matches_pattern(component_id, pattern, exact_match=False): + """ + Helper function to check if a component matches a pattern based on its base name. + + Args: + component_id: The component ID to check + pattern: The pattern to match against + exact_match: If True, only exact matches to base_name are considered + """ + base_name = base_names[component_id] + + # Exact match with base name + if exact_match: + return pattern == base_name + + # Prefix match (ends with *) + elif pattern.endswith('*'): + prefix = pattern[:-1] + return base_name.startswith(prefix) + + # Contains match (starts with *) + elif pattern.startswith('*'): + search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] + return search in base_name + + # Exact match (no wildcards) + else: + return pattern == base_name + if isinstance(names, str): # Check if this is a "not" pattern is_not_pattern = names.startswith('!') @@ -286,33 +408,45 @@ def get(self, names: Union[str, List[str]]): # Handle OR patterns (containing |) if '|' in names: terms = names.split('|') + matches = {} + + for comp_id, comp in components.items(): + # For OR patterns with exact names (no wildcards), we do exact matching on base names + exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) + + # Check if any of the terms match this component + should_include = any(matches_pattern(comp_id, term, exact_match) for term in 
terms) + + # Flip the decision if this is a NOT pattern + if is_not_pattern: + should_include = not should_include + + if should_include: + matches[comp_id] = comp + + log_msg = "NOT " if is_not_pattern else "" + match_type = "exactly matching" if exact_match else "matching any of patterns" + logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") + + # Try exact match with a base name + elif any(names == base_name for base_name in base_names.values()): + # Find all components with this base name matches = { - name: comp for name, comp in self.components.items() - if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if (base_names[comp_id] == names) != is_not_pattern } + if is_not_pattern: - logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}") - else: - logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}") - - # Exact match - elif names in self.components: - if is_not_pattern: - matches = { - name: comp for name, comp in self.components.items() - if name != names - } - logger.info(f"Getting all components except '{names}': {list(matches.keys())}") + logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: - logger.info(f"Getting component: {names}") - return self.components[names] + logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") # Prefix match (ends with *) elif names.endswith('*'): prefix = names[:-1] matches = { - name: comp for name, comp in self.components.items() - if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if base_names[comp_id].startswith(prefix) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") @@ -323,30 +457,46 
@@ def get(self, names: Union[str, List[str]]): elif names.startswith('*'): search = names[1:-1] if names.endswith('*') else names[1:] matches = { - name: comp for name, comp in self.components.items() - if (search in name) != is_not_pattern # Flip condition if not pattern + comp_id: comp for comp_id, comp in components.items() + if (search in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{search}': {list(matches.keys())}") + # Substring match (no wildcards, but not an exact component name) + elif any(names in base_name for base_name in base_names.values()): + matches = { + comp_id: comp for comp_id, comp in components.items() + if (names in base_names[comp_id]) != is_not_pattern + } + if is_not_pattern: + logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") + else: + logger.info(f"Getting components containing '{names}': {list(matches.keys())}") + else: - raise ValueError(f"Component '{names}' not found in ComponentsManager") + raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") if not matches: raise ValueError(f"No components found matching pattern '{names}'") - return matches if len(matches) > 1 else next(iter(matches.values())) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] + else: + return matches elif isinstance(names, list): results = {} for name in names: - result = self.get(name) - if isinstance(result, dict): - results.update(result) - else: - results[name] = result - return results + result = self.get(name, collection, load_id, as_name_component_tuples=False) + results.update(result) + + if as_name_component_tuples: + return [(base_names[comp_id], comp) for comp_id, comp in results.items()] + else: + return results else: raise ValueError(f"Invalid type for names: {type(names)}") @@ 
-390,6 +540,7 @@ def disable_auto_cpu_offload(self): self.model_hooks = None self._auto_offload_enabled = False + # YiYi TODO: add quantization info def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]: """Get comprehensive information about a component. @@ -412,14 +563,23 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No info = { "model_id": name, "added_time": self.added_time[name], + "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), } # Additional info for torch.nn.Module components if isinstance(component, torch.nn.Module): + # Check for hook information + has_hook = hasattr(component, "_hf_hook") + execution_device = None + if has_hook and hasattr(component._hf_hook, "execution_device"): + execution_device = component._hf_hook.execution_device + info.update({ "class_name": component.__class__.__name__, "size_gb": get_memory_footprint(component) / (1024**3), "adapters": None, # Default to None + "has_hook": has_hook, + "execution_device": execution_device, }) # Get adapters if applicable @@ -453,12 +613,56 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No return info def __repr__(self): + # Helper to get simple name without UUID + def get_simple_name(name): + # Extract the base name by splitting on underscore and taking first part + # This assumes names are in format "name_uuid" + parts = name.split('_') + # If we have at least 2 parts and the last part looks like a UUID, remove it + if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: + return '_'.join(parts[:-1]) + return name + + # Extract load_id if available + def get_load_id(component): + if hasattr(component, "_diffusers_load_id"): + return component._diffusers_load_id + return "N/A" + + # Format device info compactly + def format_device(component, info): + if not info["has_hook"]: + return str(getattr(component, 'device', 
'N/A')) + else: + device = str(getattr(component, 'device', 'N/A')) + exec_device = str(info['execution_device'] or 'N/A') + return f"{device}({exec_device})" + + # Get all simple names to calculate width + simple_names = [get_simple_name(id) for id in self.components.keys()] + + # Get max length of load_ids for models + load_ids = [ + get_load_id(component) + for component in self.components.values() + if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") + ] + max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 + + # Collection names + collection_names = [ + next((coll for coll, comps in self.collections.items() if name in comps), "N/A") + for name in self.components.keys() + ] + col_widths = { - "id": max(15, max(len(id) for id in self.components.keys())), + "name": max(15, max(len(name) for name in simple_names)), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), - "device": 10, + "device": 15, # Reduced since using more compact format "dtype": 15, "size": 10, + "load_id": max_load_id_len, + "collection": max(10, max(len(str(c)) for c in collection_names)) } # Create the header lines @@ -475,17 +679,23 @@ def __repr__(self): if models: output += "Models:\n" + dash_line # Column headers - output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | " - output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | " + output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | " + output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n" output += dash_line # Model entries for name, component in models.items(): info = self.get_model_info(name) - device = str(getattr(component, "device", "N/A")) + simple_name = get_simple_name(name) + device_str 
= format_device(component, info) dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" - output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | " - output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n" + load_id = get_load_id(component) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " + output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " + output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" output += dash_line # Other components section @@ -494,12 +704,16 @@ def __repr__(self): output += "\n" output += "Other Components:\n" + dash_line # Column headers for other components - output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n" + output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n" output += dash_line # Other component entries for name, component in others.items(): - output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n" + info = self.get_model_info(name) + simple_name = get_simple_name(name) + collection = info["collection"] or "N/A" + + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" output += dash_line # Add additional component info @@ -507,7 +721,8 @@ def __repr__(self): for name in self.components: info = self.get_model_info(name) if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): - output += f"\n{name}:\n" + simple_name = get_simple_name(name) + output += f"\n{simple_name}:\n" if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): @@ -516,7 +731,7 @@ def __repr__(self): return output - def 
add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): + def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): """ Load components from a pretrained model and add them to the manager. @@ -526,17 +741,12 @@ def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[st If provided, components will be named as "{prefix}_{component_name}" **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained() """ - from ..pipelines.pipeline_utils import DiffusionPipeline - - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) - for name, component in pipe.components.items(): - - if component is None: - continue - - # Add prefix if specified - component_name = f"{prefix}_{name}" if prefix else name - + subfolder = kwargs.pop("subfolder", None) + # YiYi TODO: extend AutoModel to support non-diffusers models + if subfolder: + from ..models import AutoModel + component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs) + component_name = f"{prefix}_{subfolder}" if prefix else subfolder if component_name not in self.components: self.add(component_name, component) else: @@ -545,6 +755,50 @@ def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[st f"1. remove the existing component with remove('{component_name}')\n" f"2. 
Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" ) + else: + from ..pipelines.pipeline_utils import DiffusionPipeline + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) + for name, component in pipe.components.items(): + + if component is None: + continue + + # Add prefix if specified + component_name = f"{prefix}_{name}" if prefix else name + + if component_name not in self.components: + self.add(component_name, component) + else: + logger.warning( + f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n" + f"1. remove the existing component with remove('{component_name}')\n" + f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')" + ) + + def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any: + """ + Get a single component by name. Raises an error if multiple components match or none are found. + + Args: + name: Component name or pattern + collection: Optional collection to filter by + load_id: Optional load_id to filter by + + Returns: + A single component + + Raises: + ValueError: If no components match or multiple components match + """ + results = self.get(name, collection, load_id) + + if not results: + raise ValueError(f"No components found matching '{name}'") + + if len(results) > 1: + raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") + + return next(iter(results.values())) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: """Summarizes a dictionary by finding common prefixes that share the same value. 
diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 785f38cdbf8c..636b543395df 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -22,25 +22,45 @@ import torch from tqdm.auto import tqdm import re +import os +import importlib -from ..configuration_utils import ConfigMixin +from huggingface_hub.utils import validate_hf_hub_args + +from ..configuration_utils import ConfigMixin, FrozenDict from ..utils import ( is_accelerate_available, is_accelerate_version, logging, + PushToHubMixin, ) -from .pipeline_loading_utils import _get_pipeline_class - +from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple +from .modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + format_components, + format_configs, + format_input_params, + format_inputs_short, + format_intermediates_short, + format_output_params, + format_params, + make_doc_string, +) +from .components_manager import ComponentsManager +from copy import deepcopy if is_accelerate_available(): import accelerate logger = logging.get_logger(__name__) # pylint: disable=invalid-name -MODULAR_PIPELINE_MAPPING = OrderedDict( +MODULAR_LOADER_MAPPING = OrderedDict( [ - ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), + ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), ] ) @@ -138,236 +158,116 @@ def format_value(v): return f"BlockState(\n{attributes}\n)" -@dataclass -class ComponentSpec: - """Specification for a pipeline component.""" - name: str - type_hint: Type - description: Optional[str] = None - obj: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor - default_class_name: Union[str, List[str], Tuple[str, str]] = None # Either "class_name" or ["module", "class_name"] - default_repo: Optional[Union[str, List[str]]] = None # either "repo" or 
["repo", "subfolder"] - -@dataclass -class ConfigSpec: - """Specification for a pipeline configuration parameter.""" - name: str - default: Any - description: Optional[str] = None - - -@dataclass -class InputParam: - name: str - type_hint: Any = None - default: Any = None - required: bool = False - description: str = "" - - def __repr__(self): - return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" -@dataclass -class OutputParam: - name: str - type_hint: Any = None - description: str = "" - - def __repr__(self): - return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" - -def format_inputs_short(inputs): +class ModularPipelineMixin: """ - Format input parameters into a string representation, with required params first followed by optional ones. - - Args: - inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params - - Returns: - str: Formatted string of input parameters - - Example: - >>> inputs = [ - ... InputParam(name="prompt", required=True), - ... InputParam(name="image", required=True), - ... InputParam(name="guidance_scale", required=False, default=7.5), - ... InputParam(name="num_inference_steps", required=False, default=50) - ... 
] - >>> format_inputs_short(inputs) - 'prompt, image, guidance_scale=7.5, num_inference_steps=50' + Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks """ - required_inputs = [param for param in inputs if param.required] - optional_inputs = [param for param in inputs if not param.required] - - required_str = ", ".join(param.name for param in required_inputs) - optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - inputs_str = required_str - if optional_str: - inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - - return inputs_str + def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + """ + create a mouldar loader, optionally accept modular_repo to load from hub. + """ -def format_intermediates_short(intermediates_inputs: List[InputParam], required_intermediates_inputs: List[str], intermediates_outputs: List[OutputParam]) -> str: - """ - Formats intermediate inputs and outputs of a block into a string representation. 
- - Args: - intermediates_inputs: List of intermediate input parameters - required_intermediates_inputs: List of required intermediate input names - intermediates_outputs: List of intermediate output parameters - - Returns: - str: Formatted string like: - Intermediates: - - inputs: Required(latents), dtype - - modified: latents # variables that appear in both inputs and outputs - - outputs: images # new outputs only - """ - # Handle inputs - input_parts = [] - for inp in intermediates_inputs: - if inp.name in required_intermediates_inputs: - input_parts.append(f"Required({inp.name})") - else: - input_parts.append(inp.name) - - # Handle modified variables (appear in both inputs and outputs) - inputs_set = {inp.name for inp in intermediates_inputs} - modified_parts = [] - new_output_parts = [] - - for out in intermediates_outputs: - if out.name in inputs_set: - modified_parts.append(out.name) - else: - new_output_parts.append(out.name) - - result = [] - if input_parts: - result.append(f" - inputs: {', '.join(input_parts)}") - if modified_parts: - result.append(f" - modified: {', '.join(modified_parts)}") - if new_output_parts: - result.append(f" - outputs: {', '.join(new_output_parts)}") + # Import components loader (it is model-specific class) + loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] + diffusers_module = importlib.import_module("diffusers") + loader_class = getattr(diffusers_module, loader_class_name) + + # Create deep copies to avoid modifying the original specs + component_specs = deepcopy(self.expected_components) + config_specs = deepcopy(self.expected_configs) + # Create the loader with the updated specs + specs = component_specs + config_specs - return "\n".join(result) if result else " (none)" + self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) -def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: 
int = 115) -> str: - """Format a list of InputParam or OutputParam objects into a readable string representation. + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for input_param in self.inputs: + params[input_param.name] = input_param.default + return params - Args: - params: List of InputParam or OutputParam objects to format - header: Header text to use (e.g. "Args" or "Returns") - indent_level: Number of spaces to indent each parameter line (default: 4) - max_line_length: Maximum length for each line before wrapping (default: 115) + def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() - Returns: - A formatted string representing all parameters - """ - if not params: - return "" - - base_indent = " " * indent_level - param_indent = " " * (indent_level + 4) - desc_indent = " " * (indent_level + 8) - formatted_params = [] - - def get_type_str(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: - types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] - return f"Union[{', '.join(types)}]" - return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - - def wrap_text(text: str, indent: str, max_length: int) -> str: - """Wrap text while preserving markdown links and maintaining indentation.""" - words = text.split() - lines = [] - current_line = [] - current_length = 0 - - for word in words: - word_length = len(word) + (1 if current_line else 0) - - if current_line and current_length + word_length > max_length: - lines.append(" ".join(current_line)) - current_line = [word] - current_length = len(word) - else: - current_line.append(word) - current_length += word_length - - if current_line: - lines.append(" ".join(current_line)) - - return f"\n{indent}".join(lines) - - 
# Add the header - formatted_params.append(f"{base_indent}{header}:") - - for param in params: - # Format parameter name and type - type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" - param_str = f"{param_indent}{param.name} (`{type_str}`" - - # Add optional tag and default value if parameter is an InputParam and optional - if isinstance(param, InputParam): - if not param.required: - param_str += ", *optional*" - if param.default is not None: - param_str += f", defaults to {param.default}" - param_str += "):" - - # Add description on a new line with additional indentation and wrapping - if param.description: - desc = re.sub( - r'\[(.*?)\]\((https?://[^\s\)]+)\)', - r'[\1](\2)', - param.description - ) - wrapped_desc = wrap_text(desc, desc_indent, max_line_length) - param_str += f"\n{desc_indent}{wrapped_desc}" - - formatted_params.append(param_str) - - return "\n\n".join(formatted_params) + if not hasattr(self, "loader"): + raise ValueError("Loader is not set, please call `setup_loader()` first.") -# Then update the original functions to use this combined version: -def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(input_params, "Args", indent_level, max_line_length) + # Make a copy of the input kwargs + input_params = kwargs.copy() -def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str: - return format_params(output_params, "Returns", indent_level, max_line_length) + default_params = self.default_call_parameters + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + intermediates_inputs = [inp.name for inp in self.intermediates_inputs] + for name, default in default_params.items(): + if name in input_params: + if name not in intermediates_inputs: + state.add_input(name, 
input_params.pop(name)) + else: + state.add_input(name, input_params[name]) + elif name not in state.inputs: + state.add_input(name, default) -def make_doc_string(inputs, intermediates_inputs, outputs, description=""): - """ - Generates a formatted documentation string describing the pipeline block's parameters and structure. - - Returns: - str: A formatted string containing information about call parameters, intermediate inputs/outputs, - and final intermediate outputs. - """ - output = "" + for name in intermediates_inputs: + if name in input_params: + state.add_intermediate(name, input_params.pop(name)) + + # Warn about unexpected inputs + if len(input_params) > 0: + logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + pipeline, state = self(self.loader, state) + except Exception: + error_msg = f"Error in block: ({self.__class__.__name__}):\n" + logger.error(error_msg) + raise - if description: - desc_lines = description.strip().split('\n') - aligned_desc = '\n'.join(' ' + line for line in desc_lines) - output += aligned_desc + "\n\n" + if output is None: + return state - output += format_input_params(inputs + intermediates_inputs, indent_level=2) - - output += "\n\n" - output += format_output_params(outputs, indent_level=2) - return output + elif isinstance(output, str): + return state.get_intermediate(output) + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 
+ ) -class PipelineBlock: + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs + + +class PipelineBlock(ModularPipelineMixin): model_name = None @@ -440,31 +340,15 @@ def __repr__(self): desc.extend(f" {line}" for line in desc_lines[1:]) desc = '\n'.join(desc) + '\n' - # Components section - focus only on expected components + # Components section - use format_components with add_empty_lines=False expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] - - for component_spec in expected_components: - component_str = f" - {component_spec.name} ({component_spec.type_hint})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) + components = " " + components_str.replace("\n", "\n ") - components = "Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs = "Configs:\n" + "\n".join(f" - {k}" for k in sorted(expected_configs)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) + configs = " " + configs_str.replace("\n", "\n ") # Inputs section 
inputs_str = format_inputs_short(self.inputs) @@ -478,8 +362,8 @@ def __repr__(self): f"{class_name}(\n" f" Class: {base_class}\n" f"{desc}" - f" {components}\n" - f" {configs}\n" + f"{components}\n" + f"{configs}\n" f" {inputs}\n" f" {intermediates}\n" f")" @@ -488,7 +372,15 @@ def __repr__(self): @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) def get_block_state(self, state: PipelineState) -> dict: @@ -575,7 +467,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> return list(combined_dict.values()) -class AutoPipelineBlocks: +class AutoPipelineBlocks(ModularPipelineMixin): """ A class that automatically selects a block to run based on the inputs. 
@@ -796,32 +688,13 @@ def __repr__(self): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - for component_spec in expected_components: - - component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) - - components_str = " Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Blocks section + # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -846,54 +719,31 @@ def __repr__(self): indented_desc = desc_lines[0] if len(desc_lines) > 1: indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - 
block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - # Inputs and outputs section - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) + blocks_str += f" Description: {indented_desc}\n\n" return ( f"{header}\n" - f"{desc}" - f"{components_str}\n" - f"{configs_str}\n" - f"{blocks_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" + f"{desc}\n\n" + f"{components_str}\n\n" + f"{configs_str}\n\n" + f"{blocks_str}" f")" ) + @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) -class SequentialPipelineBlocks: +class SequentialPipelineBlocks(ModularPipelineMixin): """ A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. 
""" @@ -1168,32 +1018,13 @@ def __repr__(self): # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) - expected_components_str_list = [] + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - for component_spec in expected_components: - - component_str = f" - {component_spec.name} ({component_spec.type_hint.__name__})" - - # Add repo info if available - if component_spec.default_repo: - if isinstance(component_spec.default_repo, list) and len(component_spec.default_repo) == 2: - repo_info = component_spec.default_repo[0] - subfolder = component_spec.default_repo[1] - if subfolder: - repo_info += f", subfolder={subfolder}" - else: - repo_info = component_spec.default_repo - component_str += f" [{repo_info}]" - - expected_components_str_list.append(component_str) - - components_str = " Components:\n" + "\n".join(expected_components_str_list) - - # Configs section - focus only on expected configs + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) - configs_str = " Configs:\n" + "\n".join(f" - {config.name}" for config in sorted(expected_configs, key=lambda x: x.name)) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Blocks section + # Blocks section - moved to the end with simplified format blocks_str = " Blocks:\n" for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block @@ -1218,85 +1049,172 @@ def __repr__(self): indented_desc = desc_lines[0] if len(desc_lines) > 1: indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) - blocks_str += f" Description: {indented_desc}\n" - - # Format inputs - inputs_str = format_inputs_short(block.inputs) - blocks_str += f" inputs: {inputs_str}\n" - - # Format intermediates - intermediates_str = format_intermediates_short( - block.intermediates_inputs, - 
block.required_intermediates_inputs, - block.intermediates_outputs - ) - if intermediates_str != " (none)": - blocks_str += " intermediates:\n" - indented_intermediates = "\n".join( - " " + line for line in intermediates_str.split("\n") - ) - blocks_str += f"{indented_intermediates}\n" - blocks_str += "\n" - - # Inputs and outputs section - inputs_str = format_inputs_short(self.inputs) - inputs_str = " Inputs:\n " + inputs_str - outputs = [out.name for out in self.outputs] - - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) - intermediates_str = ( - "\n Intermediates:\n" - f"{intermediates_str}\n" - f" - final outputs: {', '.join(outputs)}" - ) + blocks_str += f" Description: {indented_desc}\n\n" return ( f"{header}\n" - f"{desc}" - f"{components_str}\n" - f"{configs_str}\n" - f"{blocks_str}\n" - f"{inputs_str}\n" - f"{intermediates_str}\n" + f"{desc}\n\n" + f"{components_str}\n\n" + f"{configs_str}\n\n" + f"{blocks_str}" f")" ) @property def doc(self): - return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description) + return make_doc_string( + self.inputs, + self.intermediates_inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs + ) -class ModularPipeline(ConfigMixin): + + +# YiYi TODO: +# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) +# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader +# 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() +class ModularLoader(ConfigMixin, PushToHubMixin): """ - Base class for all Modular pipelines. + Base class for all Modular pipelines loaders. 
""" + config_name = "modular_model_index.json" + + + def register_components(self, **kwargs): + """ + Register components with their corresponding specs. + This method is called when component changed or __init__ is called. - config_name = "model_index.json" - _exclude_from_cpu_offload = [] + Args: + **kwargs: Keyword arguments where keys are component names and values are component objects. + + """ + for name, module in kwargs.items(): + + # current component spec + component_spec = self._component_specs.get(name) + if component_spec is None: + logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") + continue + + is_registered = hasattr(self, name) - def __init__(self, block): - self.pipeline_block = block + if module is not None and not hasattr(module, "_diffusers_load_id"): + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - for component_spec in self.expected_components: - if component_spec.obj is not None: - setattr(self, component_spec.name, component_spec.obj) + # actual library and class name of the module + + if module is not None: + library, class_name = _fetch_class_library_tuple(module) + new_component_spec = ComponentSpec.from_component(name, module) + component_spec_dict = self._component_spec_to_dict(new_component_spec) + + else: + library, class_name = None, None + # if module is None, we do not update the spec, + # but we still need to update the config to make sure it's synced with the component spec + # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) + new_component_spec = component_spec + component_spec_dict = self._component_spec_to_dict(component_spec) + + # do not register if component is not to be loaded from pretrained + if new_component_spec.default_creation_method == "from_pretrained": + register_dict = {name: (library, class_name, component_spec_dict)} 
else: - setattr(self, component_spec.name, None) + register_dict = {} + + # set the component as attribute + # if it is not set yet, just set it and skip the process to check and warn below + if not is_registered: + self.register_to_config(**register_dict) + self._component_specs[name] = new_component_spec + setattr(self, name, module) + if module is not None and self._component_manager is not None: + self._component_manager.add(name, module, self._collection) + continue + + current_module = getattr(self, name, None) + # skip if the component is already registered with the same object + if current_module is module: + logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") + continue + + # it module is not an instance of the expected type, still register it but with a warning + if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): + logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") + + # warn if unregister + if current_module is not None and module is None: + logger.info( + f"ModularLoader.register_components: setting '{name}' to None " + f"(was {current_module.__class__.__name__})" + ) + # same type, new instance → debug + elif current_module is not None \ + and module is not None \ + and isinstance(module, current_module.__class__) \ + and current_module != module: + logger.debug( + f"ModularLoader.register_components: replacing existing '{name}' " + f"(same type {type(current_module).__name__}, new instance)" + ) + + # save modular_model_index.json config + self.register_to_config(**register_dict) + # update component spec + self._component_specs[name] = new_component_spec + # finally set models + setattr(self, name, module) + if module is not None and self._component_manager is not None: + self._component_manager.add(name, module, 
self._collection) + + + + # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name + def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + """ + Initialize the loader with a list of component specs and config specs. + """ + self._component_manager = component_manager + self._collection = collection + self._component_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) + } + self._config_specs = { + spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) + } + + # update component_specs and config_specs from modular_repo + if modular_repo is not None: + config_dict = self.load_config(modular_repo, **kwargs) + + for name, value in config_dict.items(): + if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = self._dict_to_component_spec(name, component_spec_dict) + self._component_specs[name] = component_spec + + elif name in self._config_specs: + self._config_specs[name].default = value + + register_components_dict = {} + for name, component_spec in self._component_specs.items(): + register_components_dict[name] = None + self.register_components(**register_components_dict) default_configs = {} - for config_spec in self.expected_configs: - default_configs[config_spec.name] = config_spec.default + for name, config_spec in self._config_specs.items(): + default_configs[name] = config_spec.default self.register_to_config(**default_configs) - @classmethod - def from_block(cls, block): - modular_pipeline_class_name = MODULAR_PIPELINE_MAPPING[block.model_name] - modular_pipeline_class = _get_pipeline_class(cls, class_name=modular_pipeline_class_name) - - 
return modular_pipeline_class(block) - @property def device(self) -> torch.device: r""" @@ -1320,7 +1238,7 @@ def _execution_device(self): Accelerate's module hooks. """ for name, model in self.components.items(): - if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: + if not isinstance(model, torch.nn.Module): continue if not hasattr(model, "_hf_hook"): @@ -1333,11 +1251,21 @@ def _execution_device(self): ): return torch.device(module._hf_hook.execution_device) return self.device - - - def get_execution_blocks(self, *trigger_inputs): - return self.pipeline_block.get_execution_blocks(*trigger_inputs) + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + + modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + @property def dtype(self) -> torch.dtype: r""" @@ -1352,340 +1280,257 @@ def dtype(self) -> torch.dtype: return torch.float32 - @property - def expected_components(self): - return self.pipeline_block.expected_components - - @property - def expected_configs(self): - return self.pipeline_block.expected_configs @property - def components(self): - components = {} - for component_spec in self.expected_components: - if hasattr(self, component_spec.name): - components[component_spec.name] = getattr(self, component_spec.name) - return components - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.progress_bar - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 
- ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.set_progress_bar_config - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + def components(self) -> Dict[str, Any]: + # return only components we've actually set as attributes on self + return { + name: getattr(self, name) + for name in self._component_specs.keys() + if hasattr(self, name) + } + + def update(self, **kwargs): """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - # Make a copy of the input kwargs - input_params = kwargs.copy() - - default_params = self.default_call_parameters + Update components and configs after instance creation. + + Args: - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs + """ + """ + Update components and configuration values after the loader has been instantiated. + + This method allows you to: + 1. Replace existing components with new ones (e.g., updating the unet or text_encoder) + 2. 
Update configuration values (e.g., changing requires_safety_checker flag) + + Args: + **kwargs: Component objects or configuration values to update: + - Component objects: Must be created using ComponentSpec (e.g., `unet=new_unet, text_encoder=new_encoder`) + - Configuration values: Simple values to update configuration settings (e.g., `requires_safety_checker=False`) + + Raises: + ValueError: If a component wasn't created using ComponentSpec (doesn't have `_diffusers_load_id` attribute) + + Examples: + ```python + # Update multiple components at once + loader.update( + unet=new_unet_model, + text_encoder=new_text_encoder + ) + + # Update configuration values + loader.update( + requires_safety_checker=False, + guidance_rescale=0.7 + ) + + # Update both components and configs together + loader.update( + unet=new_unet_model, + requires_safety_checker=False + ) + ``` + """ - intermediates_inputs = [inp.name for inp in self.pipeline_block.intermediates_inputs] - for name, default in default_params.items(): - if name in input_params: - if name not in intermediates_inputs: - state.add_input(name, input_params.pop(name)) - else: - state.add_input(name, input_params[name]) - elif name not in state.inputs: - state.add_input(name, default) + # extract component_specs_updates & config_specs_updates from `specs` + passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} + passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs} - for name in intermediates_inputs: - if name in input_params: - state.add_intermediate(name, input_params.pop(name)) + for name, component in passed_components.items(): + if not hasattr(component, "_diffusers_load_id"): + raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + + if len(kwargs) > 0: + logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") + - # Warn about unexpected inputs - if len(input_params) > 0: - 
logger.warning(f"Unexpected input '{input_params.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self.pipeline_block(self, state) - except Exception: - error_msg = f"Error in block: ({self.pipeline_block.__class__.__name__}):\n" - logger.error(error_msg) - raise + self.register_components(**passed_components) - if output is None: - return state + config_to_register = {} + for name, new_value in passed_config_values.items(): - elif isinstance(output, str): - return state.get_intermediate(output) + # e.g. requires_aesthetics_score = False + self._config_specs[name].default = new_value + config_to_register[name] = new_value + self.register_to_config(**config_to_register) - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - def update_states(self, **kwargs): + # YiYi TODO: support map for additional from_pretrained kwargs + def load(self, component_names: Optional[List[str]] = None, **kwargs): """ - Update components and configs after instance creation. Auxiliaries (e.g. image_processor) should be defined for - each pipeline block, does not need to be updated by users. Logs if existing non-None components are being - overwritten. - + Load selectedcomponents from specs. + Args: - kwargs (dict): Keyword arguments to update the states. + component_names: List of component names to load + **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: + - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 + - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} + - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`, `variant`, `revision`, etc. 
""" - - for component in self.expected_components: - if component.name in kwargs: - if hasattr(self, component.name) and getattr(self, component.name) is not None: - current_component = getattr(self, component.name) - new_component = kwargs[component.name] - - if not isinstance(new_component, current_component.__class__): - logger.info( - f"Overwriting existing component '{component.name}' " - f"(type: {current_component.__class__.__name__}) " - f"with type: {new_component.__class__.__name__})" - ) - elif isinstance(current_component, torch.nn.Module): - if id(current_component) != id(new_component): - logger.info( - f"Overwriting existing component '{component.name}' " - f"(type: {type(current_component).__name__}) " - f"with new value (type: {type(new_component).__name__})" - ) - - setattr(self, component.name, kwargs.pop(component.name)) - - configs_to_add = {} - for config in self.expected_configs: - if config.name in kwargs: - configs_to_add[config.name] = kwargs.pop(config.name) - self.register_to_config(**configs_to_add) - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.pipeline_block.inputs: - params[input_param.name] = input_param.default - return params - - # def __repr__(self): - # output = "ModularPipeline:\n" - # output += "==============================\n\n" - - # block = self.pipeline_block + if component_names is None: + component_names = list(self._component_specs.keys()) + elif not isinstance(component_names, list): + component_names = [component_names] + + components_to_load = set([name for name in component_names if name in self._component_specs]) + unknown_component_names = set([name for name in component_names if name not in self._component_specs]) + if len(unknown_component_names) > 0: + logger.warning(f"Unknown components will be ignored: {unknown_component_names}") - # # List the pipeline block structure first - # output += "Pipeline Block:\n" - # output += "--------------\n" - # if 
hasattr(block, "blocks"): - # output += f"{block.__class__.__name__}\n" - # base_class = block.__class__.__bases__[0].__name__ - # output += f" (Class: {base_class})\n" if base_class != "object" else "\n" - # for sub_block_name, sub_block in block.blocks.items(): - # if hasattr(block, "block_trigger_inputs"): - # trigger_input = block.block_to_trigger_map[sub_block_name] - # trigger_info = f" [trigger: {trigger_input}]" if trigger_input is not None else " [default]" - # output += f" • {sub_block_name} ({sub_block.__class__.__name__}){trigger_info}\n" - # else: - # output += f" • {sub_block_name} ({sub_block.__class__.__name__})\n" - # else: - # output += f"{block.__class__.__name__}\n" - # output += "\n" - - # # List the components registered in the pipeline - # output += "Registered Components:\n" - # output += "----------------------\n" - # for name, component in self.components.items(): - # output += f"{name}: {type(component).__name__}" - # if hasattr(component, "dtype") and hasattr(component, "device"): - # output += f" (dtype={component.dtype}, device={component.device})" - # output += "\n" - # output += "\n" - - # # List the configs registered in the pipeline - # output += "Registered Configs:\n" - # output += "------------------\n" - # for name, config in self.config.items(): - # output += f"{name}: {config!r}\n" - # output += "\n" - - # # Add auto blocks section - # if hasattr(block, "trigger_inputs") and block.trigger_inputs: - # output += "------------------\n" - # output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n" - # output += f"Trigger Inputs: {block.trigger_inputs}\n" - # # Get first trigger input as example - # example_input = next(t for t in block.trigger_inputs if t is not None) - # output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. 
`get_execution_blocks('{example_input}')`).\n" - # output += "Check `.doc` of returned object for more information.\n\n" - - # # List the call parameters - # full_doc = self.pipeline_block.doc - # if "------------------------" in full_doc: - # full_doc = full_doc.split("------------------------")[0].rstrip() - # output += full_doc - - # return output - - # YiYi TODO: try to unify the to method with the one in DiffusionPipeline - # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to - def to(self, *args, **kwargs): - r""" - Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the - arguments of `self.to(*args, **kwargs).` - - - - If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise, - the returned pipeline is a copy of self with the desired torch.dtype and torch.device. - - + components_to_register = {} + for name in components_to_load: + spec = self._component_specs[name] + component_load_kwargs = {} + for key, value in kwargs.items(): + if not isinstance(value, dict): + # if the value is a single value, apply it to all components + component_load_kwargs[key] = value + else: + if name in value: + # if it is a dict, check if the component name is in the dict + component_load_kwargs[key] = value[name] + elif "default" in value: + # check if the default is specified + component_load_kwargs[key] = value["default"] + try: + components_to_register[name] = spec.create(**component_load_kwargs) + except Exception as e: + logger.warning(f"Failed to create component '{name}': {e}") + + # Register all components at once + self.register_components(**components_to_register) + # YiYi TODO: should support to method + def to(self, *args, **kwargs): + pass + + # YiYi TODO: + # 1. should support save some components too! currently only modular_model_index.json is saved + # 2. 
maybe order the json file to make it more readable: configs first, then components + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): + + component_names = list(self._component_specs.keys()) + config_names = list(self._config_specs.keys()) + self.register_to_config(_components_names=component_names, _configs_names=config_names) + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + config = dict(self.config) + config.pop("_components_names", None) + config.pop("_configs_names", None) + self._internal_dict = FrozenDict(config) - Here are the ways to call `to`: + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): + + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) + expected_component = set(config_dict.pop("_components_names")) + expected_config = set(config_dict.pop("_configs_names")) + + component_specs = [] + config_specs = [] + for name, value in config_dict.items(): + if name in expected_component and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = cls._dict_to_component_spec(name, component_spec_dict) + component_specs.append(component_spec) + + elif name in expected_config: + config_specs.append(ConfigSpec(name=name, default=value)) + + for name in expected_component: + for spec in component_specs: + if spec.name == name: + break + else: + # append a empty component spec for these not in modular_model_index + component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) + return cls(component_specs + config_specs) - - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - - `to(device, 
silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified - [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) - - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the - specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) + + @staticmethod + def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: + """ + Convert a ComponentSpec into a JSON‐serializable dict for saving in + `modular_model_index.json`. + + This dict contains: + - "type_hint": Tuple[str, str] + The fully‐qualified module path and class name of the component. + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. - Arguments: - dtype (`torch.dtype`, *optional*): - Returns a pipeline with the specified - [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype) - device (`torch.Device`, *optional*): - Returns a pipeline with the specified - [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) - silence_dtype_warnings (`str`, *optional*, defaults to `False`): - Whether to omit warnings if the target `dtype` is not compatible with the target `device`. + Args: + component_spec (ComponentSpec): + The spec object describing one pipeline component. Returns: - [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. + Dict[str, Any]: A mapping suitable for JSON serialization. 
+ + Example: + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec + >>> from diffusers.models.unet import UNet2DConditionModel + >>> spec = ComponentSpec( + ... name="unet", + ... type_hint=UNet2DConditionModel, + ... config=None, + ... repo="path/to/repo", + ... subfolder="subfolder", + ... variant=None, + ... revision=None, + ... default_creation_method="from_pretrained", + ... ) + >>> ModularLoader._component_spec_to_dict(spec) + { + "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": None, + "revision": None, + } """ - dtype = kwargs.pop("dtype", None) - device = kwargs.pop("device", None) - silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) - - dtype_arg = None - device_arg = None - if len(args) == 1: - if isinstance(args[0], torch.dtype): - dtype_arg = args[0] - else: - device_arg = torch.device(args[0]) if args[0] is not None else None - elif len(args) == 2: - if isinstance(args[0], torch.dtype): - raise ValueError( - "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." - ) - device_arg = torch.device(args[0]) if args[0] is not None else None - dtype_arg = args[1] - elif len(args) > 2: - raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`") - - if dtype is not None and dtype_arg is not None: - raise ValueError( - "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." - ) - - dtype = dtype or dtype_arg - - if device is not None and device_arg is not None: - raise ValueError( - "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." - ) - - device = device or device_arg - - # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. 
- def module_is_sequentially_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): - return False - - return hasattr(module, "_hf_hook") and ( - isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook) - or hasattr(module._hf_hook, "hooks") - and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook) - ) - - def module_is_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): - return False - - return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) - - # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer - pipeline_is_sequentially_offloaded = any( - module_is_sequentially_offloaded(module) for _, module in self.components.items() - ) - if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda": - raise ValueError( - "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." - ) - - is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 - if is_pipeline_device_mapped: - raise ValueError( - "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`." 
- ) - - # Display a warning in this case (the operation succeeds but the benefits are lost) - pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) - if pipeline_is_offloaded and device and torch.device(device).type == "cuda": - logger.warning( - f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." - ) - - modules = [m for m in self.components.values() if isinstance(m, torch.nn.Module)] - - is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded - for module in modules: - is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit - - if is_loaded_in_8bit and dtype is not None: - logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." - ) - - if is_loaded_in_8bit and device is not None: - logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." - ) - else: - module.to(device, dtype) - - if ( - module.dtype == torch.float16 - and str(device) in ["cpu"] - and not silence_dtype_warnings - and not is_offloaded - ): - logger.warning( - "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" - " is not recommended to move them to `cpu` as running them will fail. 
Please make" - " sure to use an accelerator to run the pipeline in inference, due to the lack of" - " support for`float16` operations on this device in PyTorch. Please, remove the" - " `torch_dtype=torch.float16` argument, or use another device for inference." - ) - return self + if component_spec.type_hint is not None: + lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) + else: + lib_name = None + cls_name = None + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} + return { + "type_hint": (lib_name, cls_name), + **load_spec_dict, + } + + @staticmethod + def _dict_to_component_spec( + name: str, + spec_dict: Dict[str, Any], + ) -> ComponentSpec: + """ + Reconstruct a ComponentSpec from a dict. + """ + # make a shallow copy so we can pop() safely + spec_dict = spec_dict.copy() + # pull out and resolve the stored type_hint + lib_name, cls_name = spec_dict.pop("type_hint") + if lib_name is not None and cls_name is not None: + type_hint = simple_get_class_obj(lib_name, cls_name) + else: + type_hint = None + + # re‐assemble the ComponentSpec + return ComponentSpec( + name=name, + type_hint=type_hint, + **spec_dict, + ) \ No newline at end of file diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py new file mode 100644 index 000000000000..c8064a5215aa --- /dev/null +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -0,0 +1,592 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import inspect +from dataclasses import dataclass, asdict, field, fields +from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal + +from ..utils.import_utils import is_torch_available +from ..configuration_utils import FrozenDict, ConfigMixin + +if is_torch_available(): + import torch + + +# YiYi TODO: +# 1. validate the dataclass fields +# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained() +@dataclass +class ComponentSpec: + """Specification for a pipeline component. + + A component can be created in two ways: + 1. From scratch using __init__ with a config dict + 2. using `from_pretrained` + + Attributes: + name: Name of the component + type_hint: Type of the component (e.g. UNet2DConditionModel) + description: Optional description of the component + config: Optional config dict for __init__ creation + repo: Optional repo path for from_pretrained creation + subfolder: Optional subfolder in repo + variant: Optional variant in repo + revision: Optional revision in repo + default_creation_method: Preferred creation method - "from_config" or "from_pretrained" + """ + name: Optional[str] = None + type_hint: Optional[Type] = None + description: Optional[str] = None + config: Optional[FrozenDict[str, Any]] = None + # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? 
a bit long for a field name + repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True}) + subfolder: Optional[str] = field(default=None, metadata={"loading": True}) + variant: Optional[str] = field(default=None, metadata={"loading": True}) + revision: Optional[str] = field(default=None, metadata={"loading": True}) + default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" + + + def __hash__(self): + """Make ComponentSpec hashable, using load_id as the hash value.""" + return hash((self.name, self.load_id, self.default_creation_method)) + + def __eq__(self, other): + """Compare ComponentSpec objects based on name and load_id.""" + if not isinstance(other, ComponentSpec): + return False + return (self.name == other.name and + self.load_id == other.load_id and + self.default_creation_method == other.default_creation_method) + + @classmethod + def from_component(cls, name: str, component: torch.nn.Module) -> Any: + """Create a ComponentSpec from a Component created by `create` method.""" + + if not hasattr(component, "_diffusers_load_id"): + raise ValueError("Component is not created by `create` method") + + type_hint = component.__class__ + + if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): + config = component.config + else: + config = None + + load_spec = cls.decode_load_id(component._diffusers_load_id) + + return cls(name=name, type_hint=type_hint, config=config, **load_spec) + + @classmethod + def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: + """Create a ComponentSpec from a load_id string.""" + if load_id == "null": + raise ValueError("Cannot create ComponentSpec from null load_id") + + # Decode the load_id into a dictionary of loading fields + load_fields = cls.decode_load_id(load_id) + + # Create a new ComponentSpec instance with the decoded fields + return cls(name=name, **load_fields) + + @classmethod + def loading_fields(cls) -> List[str]: + 
""" + Return the names of all loading‐related fields + (i.e. those whose field.metadata["loading"] is True). + """ + return [f.name for f in fields(cls) if f.metadata.get("loading", False)] + + + @property + def load_id(self) -> str: + """ + Unique identifier for this spec's pretrained load, + composed of repo|subfolder|variant|revision (no empty segments). + """ + parts = [getattr(self, k) for k in self.loading_fields()] + parts = ["null" if p is None else p for p in parts] + return "|".join(p for p in parts if p) + + @classmethod + def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: + """ + Decode a load_id string back into a dictionary of loading fields and values. + + Args: + load_id: The load_id string to decode, format: "repo|subfolder|variant|revision" + where None values are represented as "null" + + Returns: + Dict mapping loading field names to their values. e.g. + { + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": "variant", + "revision": "revision" + } + If a segment value is "null", it's replaced with None. + Returns None if load_id is "null" (indicating component not loaded from pretrained). 
+ """ + + # Get all loading fields in order + loading_fields = cls.loading_fields() + result = {f: None for f in loading_fields} + + if load_id == "null": + return result + + # Split the load_id + parts = load_id.split("|") + + # Map parts to loading fields by position + for i, part in enumerate(parts): + if i < len(loading_fields): + # Convert "null" string back to None + result[loading_fields[i]] = None if part == "null" else part + + return result + + # YiYi TODO: add validator + def create(self, **kwargs) -> Any: + """Create the component using the preferred creation method.""" + + # from_pretrained creation + if self.default_creation_method == "from_pretrained": + return self.create_from_pretrained(**kwargs) + elif self.default_creation_method == "from_config": + # from_config creation + return self.create_from_config(**kwargs) + else: + raise ValueError(f"Invalid creation method: {self.default_creation_method}") + + def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: + """Create component using from_config with config.""" + + if self.type_hint is None or not isinstance(self.type_hint, type): + raise ValueError( + f"`type_hint` is required when using from_config creation method." 
+ ) + + config = config or self.config or {} + + if issubclass(self.type_hint, ConfigMixin): + component = self.type_hint.from_config(config, **kwargs) + else: + signature_params = inspect.signature(self.type_hint.__init__).parameters + init_kwargs = {} + for k, v in config.items(): + if k in signature_params: + init_kwargs[k] = v + for k, v in kwargs.items(): + if k in signature_params: + init_kwargs[k] = v + component = self.type_hint(**init_kwargs) + + component._diffusers_load_id = "null" + if hasattr(component, "config"): + self.config = component.config + + return component + + # YiYi TODO: add guard for type of model, if it is supported by from_pretrained + def create_from_pretrained(self, **kwargs) -> Any: + """Create component using from_pretrained.""" + + passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} + load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} + # repo is a required argument for from_pretrained, a.k.a. 
pretrained_model_name_or_path + repo = load_kwargs.pop("repo", None) + if repo is None: + raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + + if self.type_hint is None: + try: + from diffusers import AutoModel + component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name} without `type_hint` from pretrained: {e}") + self.type_hint = component.__class__ + else: + try: + component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) + except Exception as e: + raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") + + if repo != self.repo: + self.repo = repo + for k, v in passed_loading_kwargs.items(): + if v is not None: + setattr(self, k, v) + component._diffusers_load_id = self.load_id + + return component + + + +@dataclass +class ConfigSpec: + """Specification for a pipeline configuration parameter.""" + name: str + default: Any + description: Optional[str] = None +@dataclass +class InputParam: + """Specification for an input parameter.""" + name: str + type_hint: Any = None + default: Any = None + required: bool = False + description: str = "" + + def __repr__(self): + return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + + +@dataclass +class OutputParam: + """Specification for an output parameter.""" + name: str + type_hint: Any = None + description: str = "" + + def __repr__(self): + return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" + + +def format_inputs_short(inputs): + """ + Format input parameters into a string representation, with required params first followed by optional ones. 
+ + Args: + inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params + + Returns: + str: Formatted string of input parameters + + Example: + >>> inputs = [ + ... InputParam(name="prompt", required=True), + ... InputParam(name="image", required=True), + ... InputParam(name="guidance_scale", required=False, default=7.5), + ... InputParam(name="num_inference_steps", required=False, default=50) + ... ] + >>> format_inputs_short(inputs) + 'prompt, image, guidance_scale=7.5, num_inference_steps=50' + """ + required_inputs = [param for param in inputs if param.required] + optional_inputs = [param for param in inputs if not param.required] + + required_str = ", ".join(param.name for param in required_inputs) + optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) + + inputs_str = required_str + if optional_str: + inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str + + return inputs_str + + +def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs): + """ + Formats intermediate inputs and outputs of a block into a string representation. 
+ + Args: + intermediates_inputs: List of intermediate input parameters + required_intermediates_inputs: List of required intermediate input names + intermediates_outputs: List of intermediate output parameters + + Returns: + str: Formatted string like: + Intermediates: + - inputs: Required(latents), dtype + - modified: latents # variables that appear in both inputs and outputs + - outputs: images # new outputs only + """ + # Handle inputs + input_parts = [] + for inp in intermediates_inputs: + if inp.name in required_intermediates_inputs: + input_parts.append(f"Required({inp.name})") + else: + input_parts.append(inp.name) + + # Handle modified variables (appear in both inputs and outputs) + inputs_set = {inp.name for inp in intermediates_inputs} + modified_parts = [] + new_output_parts = [] + + for out in intermediates_outputs: + if out.name in inputs_set: + modified_parts.append(out.name) + else: + new_output_parts.append(out.name) + + result = [] + if input_parts: + result.append(f" - inputs: {', '.join(input_parts)}") + if modified_parts: + result.append(f" - modified: {', '.join(modified_parts)}") + if new_output_parts: + result.append(f" - outputs: {', '.join(new_output_parts)}") + + return "\n".join(result) if result else " (none)" + + +def format_params(params, header="Args", indent_level=4, max_line_length=115): + """Format a list of InputParam or OutputParam objects into a readable string representation. + + Args: + params: List of InputParam or OutputParam objects to format + header: Header text to use (e.g. 
"Args" or "Returns") + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all parameters + """ + if not params: + return "" + + base_indent = " " * indent_level + param_indent = " " * (indent_level + 4) + desc_indent = " " * (indent_level + 8) + formatted_params = [] + + def get_type_str(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: + types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] + return f"Union[{', '.join(types)}]" + return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) + + def wrap_text(text, indent, max_length): + """Wrap text while preserving markdown links and maintaining indentation.""" + words = text.split() + lines = [] + current_line = [] + current_length = 0 + + for word in words: + word_length = len(word) + (1 if current_line else 0) + + if current_line and current_length + word_length > max_length: + lines.append(" ".join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += word_length + + if current_line: + lines.append(" ".join(current_line)) + + return f"\n{indent}".join(lines) + + # Add the header + formatted_params.append(f"{base_indent}{header}:") + + for param in params: + # Format parameter name and type + type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" + param_str = f"{param_indent}{param.name} (`{type_str}`" + + # Add optional tag and default value if parameter is an InputParam and optional + if hasattr(param, "required"): + if not param.required: + param_str += ", *optional*" + if param.default is not None: + param_str += f", defaults to {param.default}" + param_str += "):" + + # Add description on a new line with additional indentation and wrapping + if param.description: + desc = re.sub( + 
r'\[(.*?)\]\((https?://[^\s\)]+)\)', + r'[\1](\2)', + param.description + ) + wrapped_desc = wrap_text(desc, desc_indent, max_line_length) + param_str += f"\n{desc_indent}{wrapped_desc}" + + formatted_params.append(param_str) + + return "\n\n".join(formatted_params) + + +def format_input_params(input_params, indent_level=4, max_line_length=115): + """Format a list of InputParam objects into a readable string representation. + + Args: + input_params: List of InputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all input parameters + """ + return format_params(input_params, "Inputs", indent_level, max_line_length) + + +def format_output_params(output_params, indent_level=4, max_line_length=115): + """Format a list of OutputParam objects into a readable string representation. + + Args: + output_params: List of OutputParam objects to format + indent_level: Number of spaces to indent each parameter line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + + Returns: + A formatted string representing all output parameters + """ + return format_params(output_params, "Outputs", indent_level, max_line_length) + + +def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ComponentSpec objects into a readable string representation. 
+ + Args: + components: List of ComponentSpec objects to format + indent_level: Number of spaces to indent each component line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between components (default: True) + + Returns: + A formatted string representing all components + """ + if not components: + return "" + + base_indent = " " * indent_level + component_indent = " " * (indent_level + 4) + formatted_components = [] + + # Add the header + formatted_components.append(f"{base_indent}Components:") + if add_empty_lines: + formatted_components.append("") + + # Add each component with optional empty lines between them + for i, component in enumerate(components): + # Get type name, handling special cases + type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) + + component_desc = f"{component_indent}{component.name} (`{type_name}`)" + if component.description: + component_desc += f": {component.description}" + + # Get the loading fields dynamically + loading_field_values = [] + for field_name in component.loading_fields(): + field_value = getattr(component, field_name) + if field_value is not None: + loading_field_values.append(f"{field_name}={field_value}") + + # Add loading field information if available + if loading_field_values: + component_desc += f" [{', '.join(loading_field_values)}]" + + formatted_components.append(component_desc) + + # Add an empty line after each component except the last one + if add_empty_lines and i < len(components) - 1: + formatted_components.append("") + + return "\n".join(formatted_components) + + +def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True): + """Format a list of ConfigSpec objects into a readable string representation. 
+ + Args: + configs: List of ConfigSpec objects to format + indent_level: Number of spaces to indent each config line (default: 4) + max_line_length: Maximum length for each line before wrapping (default: 115) + add_empty_lines: Whether to add empty lines between configs (default: True) + + Returns: + A formatted string representing all configs + """ + if not configs: + return "" + + base_indent = " " * indent_level + config_indent = " " * (indent_level + 4) + formatted_configs = [] + + # Add the header + formatted_configs.append(f"{base_indent}Configs:") + if add_empty_lines: + formatted_configs.append("") + + # Add each config with optional empty lines between them + for i, config in enumerate(configs): + config_desc = f"{config_indent}{config.name} (default: {config.default})" + if config.description: + config_desc += f": {config.description}" + formatted_configs.append(config_desc) + + # Add an empty line after each config except the last one + if add_empty_lines and i < len(configs) - 1: + formatted_configs.append("") + + return "\n".join(formatted_configs) + + +def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None): + """ + Generates a formatted documentation string describing the pipeline block's parameters and structure. + + Args: + inputs: List of input parameters + intermediates_inputs: List of intermediate input parameters + outputs: List of output parameters + description (str, *optional*): Description of the block + class_name (str, *optional*): Name of the class to include in the documentation + expected_components (List[ComponentSpec], *optional*): List of expected components + expected_configs (List[ConfigSpec], *optional*): List of expected configurations + + Returns: + str: A formatted string containing information about components, configs, call parameters, + intermediate inputs/outputs, and final outputs. 
+ """ + output = "" + + # Add class name if provided + if class_name: + output += f"class {class_name}\n\n" + + # Add description + if description: + desc_lines = description.strip().split('\n') + aligned_desc = '\n'.join(' ' + line for line in desc_lines) + output += aligned_desc + "\n\n" + + # Add components section if provided + if expected_components and len(expected_components) > 0: + components_str = format_components(expected_components, indent_level=2) + output += components_str + "\n\n" + + # Add configs section if provided + if expected_configs and len(expected_configs) > 0: + configs_str = format_configs(expected_configs, indent_level=2) + output += configs_str + "\n\n" + + # Add inputs section + output += format_input_params(inputs + intermediates_inputs, indent_level=2) + + # Add outputs section + output += "\n\n" + output += format_output_params(outputs, indent_level=2) + + return output \ No newline at end of file diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index a9d6c561af34..48d5992f31ee 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -333,6 +333,20 @@ def maybe_raise_or_warn( ) +# a simpler version of get_class_obj_and_candidates, it won't work with custom code +def simple_get_class_obj(library_name, class_name): + from diffusers import pipelines + is_pipeline_module = hasattr(pipelines, library_name) + + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + else: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + + return class_obj + def get_class_obj_and_candidates( library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None ): @@ -414,7 +428,7 @@ def _get_pipeline_class( revision=revision, ) - if class_obj.__name__ != 
"DiffusionPipeline" and class_obj.__name__ != "ModularPipeline": + if class_obj.__name__ != "DiffusionPipeline": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) @@ -841,7 +855,10 @@ def _fetch_class_library_tuple(module): library = not_compiled_module.__module__ # retrieve class_name - class_name = not_compiled_module.__class__.__name__ + if isinstance(not_compiled_module, type): + class_name = not_compiled_module.__name__ + else: + class_name = not_compiled_module.__class__.__name__ return (library, class_name) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c27cd434cd9a..22b0baee2e39 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1917,9 +1917,10 @@ def from_pipe(cls, pipeline, **kwargs): f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs } + optional_components = pipeline._optional_components if hasattr(pipeline, "_optional_components") and pipeline._optional_components else [] missing_modules = ( set(expected_modules) - - set(pipeline._optional_components) + - set(optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules) ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 584b260eaaa8..006836fe30d4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -34,7 +34,7 @@ "StableDiffusionXLDecodeLatentsStep", "StableDiffusionXLDenoiseStep", "StableDiffusionXLInputStep", - "StableDiffusionXLModularPipeline", + "StableDiffusionXLModularLoader", "StableDiffusionXLPrepareAdditionalConditioningStep", "StableDiffusionXLPrepareLatentsStep", "StableDiffusionXLSetTimestepsStep", @@ -65,7 +65,7 @@ StableDiffusionXLDecodeLatentsStep, StableDiffusionXLDenoiseStep, 
StableDiffusionXLInputStep, - StableDiffusionXLModularPipeline, + StableDiffusionXLModularLoader, StableDiffusionXLPrepareAdditionalConditioningStep, StableDiffusionXLPrepareLatentsStep, StableDiffusionXLSetTimestepsStep, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 2493d5635552..5ae9e63851db 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -34,7 +34,7 @@ from ..controlnet.multicontrolnet import MultiControlNetModel from ..modular_pipeline import ( AutoPipelineBlocks, - ModularPipeline, + ModularLoader, PipelineBlock, PipelineState, InputParam, @@ -56,8 +56,9 @@ CLIPVisionModelWithProjection, ) -from ...schedulers import KarrasDiffusionSchedulers -from ...guiders import GuiderType, ClassifierFreeGuidance +from ...schedulers import EulerDiscreteScheduler +from ...guiders import ClassifierFreeGuidance +from ...configuration_utils import FrozenDict import numpy as np @@ -182,9 +183,13 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("image_encoder", CLIPVisionModelWithProjection), - ComponentSpec("feature_extractor", CLIPImageProcessor), + ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), ComponentSpec("unet", UNet2DConditionModel), - ComponentSpec("guider", GuiderType), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), ] @property @@ -320,7 +325,11 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", 
CLIPTokenizer), - ComponentSpec("guider", GuiderType), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), ] @property @@ -645,7 +654,11 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), ] @property @@ -740,8 +753,16 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()), - ComponentSpec("mask_processor", VaeImageProcessor, obj=VaeImageProcessor(do_normalize=False, do_binarize=True, do_convert_grayscale=True)), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), + ComponentSpec( + "mask_processor", + VaeImageProcessor, + config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), + default_creation_method="from_config"), ] @@ -1028,7 +1049,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1151,7 +1172,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1206,7 +1227,7 @@ class 
StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1460,7 +1481,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1608,7 +1629,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec("scheduler", EulerDiscreteScheduler), ] @property @@ -1727,7 +1748,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): @property def expected_configs(self) -> List[ConfigSpec]: - return [ConfigSpec("requires_aesthetics_score", default=False),] + return [ConfigSpec("requires_aesthetics_score", False),] @property def description(self) -> str: @@ -2062,8 +2083,12 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), ] @@ -2245,7 +2270,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), ) - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: 
for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) guider_data = pipeline.guider.prepare_inputs(data) @@ -2316,11 +2341,15 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetModel), - ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), + ComponentSpec("control_image_processor", VaeImageProcessor, config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), ] @property @@ -2626,7 +2655,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ) # (5) Denoise loop - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) guider_data = pipeline.guider.prepare_inputs(data) @@ -2733,9 +2762,17 @@ def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("unet", UNet2DConditionModel), ComponentSpec("controlnet", ControlNetUnionModel), - ComponentSpec("scheduler", KarrasDiffusionSchedulers), - ComponentSpec("guider", GuiderType, obj=ClassifierFreeGuidance()), - ComponentSpec("control_image_processor", VaeImageProcessor, obj=VaeImageProcessor(do_convert_rgb=True, do_normalize=False)), + ComponentSpec("scheduler", EulerDiscreteScheduler), + 
ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config"), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config"), ] @property @@ -3029,7 +3066,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: ip_adapter_embeds=("ip_adapter_embeds", "negative_ip_adapter_embeds"), ) - with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar: + with self.progress_bar(total=data.num_inference_steps) as progress_bar: for i, t in enumerate(data.timesteps): pipeline.guider.set_state(step=i, num_inference_steps=data.num_inference_steps, timestep=t) guider_data = pipeline.guider.prepare_inputs(data) @@ -3136,7 +3173,11 @@ class StableDiffusionXLDecodeLatentsStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, obj=VaeImageProcessor()) + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config"), ] @property @@ -3527,9 +3568,14 @@ def description(self): } -# YiYi TODO: rename to components etc. and not inherit from ModularPipeline -class StableDiffusionXLModularPipeline( - ModularPipeline, +# YiYi Notes: model-specific components: +## (1) it should inherit from ModularLoader +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specific loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with ComponentsManager? 
+class StableDiffusionXLModularLoader( + ModularLoader, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index bea14cfe9c8d..f3837e39f192 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1328,7 +1328,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class ModularPipeline(metaclass=DummyObject): +class ModularLoader(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0a2c1eefae12..cbfbb842723a 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2417,7 +2417,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionXLModularPipeline(metaclass=DummyObject): +class StableDiffusionXLModularLoader(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs):