Skip to content

Commit ee84283

Browse files
committed
add componentspec and configspec
1 parent 96795af commit ee84283

File tree

3 files changed

+1215
-39
lines changed

3 files changed

+1215
-39
lines changed

src/diffusers/guider.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,3 +743,6 @@ def apply_guidance(
743743
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
744744
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
745745
return noise_pred
746+
747+
748+
Guiders = Union[CFGGuider, PAGGuider, APGGuider]

src/diffusers/pipelines/modular_pipeline.py

Lines changed: 138 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import warnings
1717
from collections import OrderedDict
1818
from dataclasses import dataclass, field
19-
from typing import Any, Dict, List, Tuple, Union
19+
from typing import Any, Dict, List, Tuple, Union, Optional, Type
2020

2121

2222
import torch
@@ -338,11 +338,28 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description=""):
338338
return output
339339

340340

341+
@dataclass
342+
class ComponentSpec:
343+
"""Specification for a pipeline component."""
344+
name: str
345+
type_hint: Optional[Type] = None
346+
description: Optional[str] = None
347+
default: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor
348+
default_class_name: Union[str, List[str], Tuple[str, str]] # Either "class_name" or ["module", "class_name"]
349+
default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"]
350+
351+
@dataclass
352+
class ConfigSpec:
353+
"""Specification for a pipeline configuration parameter."""
354+
name: str
355+
default: Any
356+
description: Optional[str] = None
357+
type_hint: Optional[Type] = None
358+
341359
class PipelineBlock:
342-
# YiYi Notes: do we need this?
343-
# pipelie block should set the default value for all expected config/components, so maybe we do not need to explicitly set the list
344-
expected_components = []
345-
expected_configs = []
360+
361+
component_specs: List[ComponentSpec] = []
362+
config_specs: List[ConfigSpec] = []
346363
model_name = None
347364

348365
@property
@@ -409,14 +426,45 @@ def __repr__(self):
409426
desc = '\n'.join(desc) + '\n'
410427

411428
# Components section
412-
expected_components = set(getattr(self, "expected_components", []))
429+
expected_components = getattr(self, "expected_components", [])
430+
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
413431
loaded_components = set(self.components.keys())
414-
all_components = sorted(expected_components | loaded_components)
432+
all_components = sorted(expected_component_names | loaded_components)
415433

416434
main_components = []
417435
auxiliary_components = []
418436
for k in all_components:
419-
component_str = f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
437+
# Get component spec if available
438+
component_spec = next((comp for comp in expected_components if comp.name == k), None)
439+
440+
if k in loaded_components:
441+
component_type = type(self.components[k]).__name__
442+
component_str = f" - {k}={component_type}"
443+
444+
# Add expected type info if available
445+
if component_spec and component_spec.class_name:
446+
expected_type = component_spec.class_name
447+
if isinstance(expected_type, (list, tuple)):
448+
expected_type = expected_type[1] # Get class name from [module, class_name]
449+
if expected_type != component_type:
450+
component_str += f" (expected: {expected_type})"
451+
else:
452+
# Component not loaded but expected
453+
if component_spec:
454+
expected_type = component_spec.class_name
455+
if isinstance(expected_type, (list, tuple)):
456+
expected_type = expected_type[1] # Get class name from [module, class_name]
457+
component_str = f" - {k} (expected: {expected_type})"
458+
459+
# Add repo info if available
460+
if component_spec.default_repo:
461+
repo_info = component_spec.default_repo
462+
if component_spec.subfolder:
463+
repo_info += f", subfolder={component_spec.subfolder}"
464+
component_str += f" [{repo_info}]"
465+
else:
466+
component_str = f" - {k}"
467+
420468
if k in getattr(self, "auxiliary_components", []):
421469
auxiliary_components.append(component_str)
422470
else:
@@ -793,18 +841,52 @@ def __repr__(self):
793841
desc = '\n'.join(desc) + '\n'
794842

795843
# Components section
796-
expected_components = set(getattr(self, "expected_components", []))
844+
expected_components = getattr(self, "expected_components", [])
845+
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
797846
loaded_components = set(self.components.keys())
798-
all_components = sorted(expected_components | loaded_components)
799-
components_str = " Components:\n" + "\n".join(
800-
f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
801-
for k in all_components
802-
)
847+
all_components = sorted(expected_component_names | loaded_components)
803848

804849
# Auxiliaries section
805850
auxiliaries_str = " Auxiliaries:\n" + "\n".join(
806851
f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
807852
)
853+
main_components = []
854+
for k in all_components:
855+
# Get component spec if available
856+
component_spec = next((comp for comp in expected_components if comp.name == k), None)
857+
858+
if k in loaded_components:
859+
component_type = type(self.components[k]).__name__
860+
component_str = f" - {k}={component_type}"
861+
862+
# Add expected type info if available
863+
if component_spec and component_spec.class_name:
864+
expected_type = component_spec.class_name
865+
if isinstance(expected_type, (list, tuple)):
866+
expected_type = expected_type[1] # Get class name from [module, class_name]
867+
if expected_type != component_type:
868+
component_str += f" (expected: {expected_type})"
869+
else:
870+
# Component not loaded but expected
871+
if component_spec:
872+
expected_type = component_spec.class_name
873+
if isinstance(expected_type, (list, tuple)):
874+
expected_type = expected_type[1] # Get class name from [module, class_name]
875+
component_str = f" - {k} (expected: {expected_type})"
876+
877+
# Add repo info if available
878+
if component_spec.default_repo:
879+
repo_info = component_spec.default_repo
880+
if component_spec.subfolder:
881+
repo_info += f", subfolder={component_spec.subfolder}"
882+
component_str += f" [{repo_info}]"
883+
else:
884+
component_str = f" - {k}"
885+
886+
887+
main_components.append(component_str)
888+
889+
components = "Components:\n" + "\n".join(main_components)
808890

809891
# Configs section
810892
expected_configs = set(getattr(self, "expected_configs", []))
@@ -1188,19 +1270,54 @@ def __repr__(self):
11881270
desc = '\n'.join(desc) + '\n'
11891271

11901272
# Components section
1191-
expected_components = set(getattr(self, "expected_components", []))
1273+
expected_components = getattr(self, "expected_components", [])
1274+
expected_component_names = {comp.name for comp in expected_components} if expected_components else set()
11921275
loaded_components = set(self.components.keys())
1193-
all_components = sorted(expected_components | loaded_components)
1194-
components_str = " Components:\n" + "\n".join(
1195-
f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
1196-
for k in all_components
1197-
)
1276+
all_components = sorted(expected_component_names | loaded_components)
11981277

11991278
# Auxiliaries section
12001279
auxiliaries_str = " Auxiliaries:\n" + "\n".join(
12011280
f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
12021281
)
12031282

1283+
main_components = []
1284+
for k in all_components:
1285+
# Get component spec if available
1286+
component_spec = next((comp for comp in expected_components if comp.name == k), None)
1287+
1288+
if k in loaded_components:
1289+
component_type = type(self.components[k]).__name__
1290+
component_str = f" - {k}={component_type}"
1291+
1292+
# Add expected type info if available
1293+
if component_spec and component_spec.class_name:
1294+
expected_type = component_spec.class_name
1295+
if isinstance(expected_type, (list, tuple)):
1296+
expected_type = expected_type[1] # Get class name from [module, class_name]
1297+
if expected_type != component_type:
1298+
component_str += f" (expected: {expected_type})"
1299+
else:
1300+
# Component not loaded but expected
1301+
if component_spec:
1302+
expected_type = component_spec.class_name
1303+
if isinstance(expected_type, (list, tuple)):
1304+
expected_type = expected_type[1] # Get class name from [module, class_name]
1305+
component_str = f" - {k} (expected: {expected_type})"
1306+
1307+
# Add repo info if available
1308+
if component_spec.default_repo:
1309+
repo_info = component_spec.default_repo
1310+
if component_spec.subfolder:
1311+
repo_info += f", subfolder={component_spec.subfolder}"
1312+
component_str += f" [{repo_info}]"
1313+
else:
1314+
component_str = f" - {k}"
1315+
1316+
1317+
main_components.append(component_str)
1318+
1319+
components = "Components:\n" + "\n".join(main_components)
1320+
12041321
# Configs section
12051322
expected_configs = set(getattr(self, "expected_configs", []))
12061323
loaded_configs = set(self.configs.keys())
@@ -1558,7 +1675,7 @@ def __repr__(self):
15581675

15591676
return output
15601677

1561-
# YiYi TO-DO: try to unify the to method with the one in DiffusionPipeline
1678+
# YiYi TODO: try to unify the to method with the one in DiffusionPipeline
15621679
# Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
15631680
def to(self, *args, **kwargs):
15641681
r"""

0 commit comments

Comments
 (0)