|
16 | 16 | import warnings
|
17 | 17 | from collections import OrderedDict
|
18 | 18 | from dataclasses import dataclass, field
|
19 |
| -from typing import Any, Dict, List, Tuple, Union |
| 19 | +from typing import Any, Dict, List, Tuple, Union, Optional, Type |
20 | 20 |
|
21 | 21 |
|
22 | 22 | import torch
|
@@ -338,11 +338,28 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description=""):
|
338 | 338 | return output
|
339 | 339 |
|
340 | 340 |
|
| 341 | +@dataclass |
| 342 | +class ComponentSpec: |
| 343 | + """Specification for a pipeline component.""" |
| 344 | + name: str |
| 345 | + type_hint: Optional[Type] = None |
| 346 | + description: Optional[str] = None |
| 347 | + default: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor |
| 348 | + default_class_name: Union[str, List[str], Tuple[str, str]] # Either "class_name" or ["module", "class_name"] |
| 349 | + default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"] |
| 350 | + |
| 351 | +@dataclass |
| 352 | +class ConfigSpec: |
| 353 | + """Specification for a pipeline configuration parameter.""" |
| 354 | + name: str |
| 355 | + default: Any |
| 356 | + description: Optional[str] = None |
| 357 | + type_hint: Optional[Type] = None |
| 358 | + |
341 | 359 | class PipelineBlock:
|
342 |
| - # YiYi Notes: do we need this? |
343 |
| - # pipelie block should set the default value for all expected config/components, so maybe we do not need to explicitly set the list |
344 |
| - expected_components = [] |
345 |
| - expected_configs = [] |
| 360 | + |
| 361 | + component_specs: List[ComponentSpec] = [] |
| 362 | + config_specs: List[ConfigSpec] = [] |
346 | 363 | model_name = None
|
347 | 364 |
|
348 | 365 | @property
|
@@ -409,14 +426,45 @@ def __repr__(self):
|
409 | 426 | desc = '\n'.join(desc) + '\n'
|
410 | 427 |
|
411 | 428 | # Components section
|
412 |
| - expected_components = set(getattr(self, "expected_components", [])) |
| 429 | + expected_components = getattr(self, "expected_components", []) |
| 430 | + expected_component_names = {comp.name for comp in expected_components} if expected_components else set() |
413 | 431 | loaded_components = set(self.components.keys())
|
414 |
| - all_components = sorted(expected_components | loaded_components) |
| 432 | + all_components = sorted(expected_component_names | loaded_components) |
415 | 433 |
|
416 | 434 | main_components = []
|
417 | 435 | auxiliary_components = []
|
418 | 436 | for k in all_components:
|
419 |
| - component_str = f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" |
| 437 | + # Get component spec if available |
| 438 | + component_spec = next((comp for comp in expected_components if comp.name == k), None) |
| 439 | + |
| 440 | + if k in loaded_components: |
| 441 | + component_type = type(self.components[k]).__name__ |
| 442 | + component_str = f" - {k}={component_type}" |
| 443 | + |
| 444 | + # Add expected type info if available |
| 445 | + if component_spec and component_spec.class_name: |
| 446 | + expected_type = component_spec.class_name |
| 447 | + if isinstance(expected_type, (list, tuple)): |
| 448 | + expected_type = expected_type[1] # Get class name from [module, class_name] |
| 449 | + if expected_type != component_type: |
| 450 | + component_str += f" (expected: {expected_type})" |
| 451 | + else: |
| 452 | + # Component not loaded but expected |
| 453 | + if component_spec: |
| 454 | + expected_type = component_spec.class_name |
| 455 | + if isinstance(expected_type, (list, tuple)): |
| 456 | + expected_type = expected_type[1] # Get class name from [module, class_name] |
| 457 | + component_str = f" - {k} (expected: {expected_type})" |
| 458 | + |
| 459 | + # Add repo info if available |
| 460 | + if component_spec.default_repo: |
| 461 | + repo_info = component_spec.default_repo |
| 462 | + if component_spec.subfolder: |
| 463 | + repo_info += f", subfolder={component_spec.subfolder}" |
| 464 | + component_str += f" [{repo_info}]" |
| 465 | + else: |
| 466 | + component_str = f" - {k}" |
| 467 | + |
420 | 468 | if k in getattr(self, "auxiliary_components", []):
|
421 | 469 | auxiliary_components.append(component_str)
|
422 | 470 | else:
|
@@ -793,18 +841,52 @@ def __repr__(self):
|
793 | 841 | desc = '\n'.join(desc) + '\n'
|
794 | 842 |
|
795 | 843 | # Components section
|
796 |
| - expected_components = set(getattr(self, "expected_components", [])) |
| 844 | + expected_components = getattr(self, "expected_components", []) |
| 845 | + expected_component_names = {comp.name for comp in expected_components} if expected_components else set() |
797 | 846 | loaded_components = set(self.components.keys())
|
798 |
| - all_components = sorted(expected_components | loaded_components) |
799 |
| - components_str = " Components:\n" + "\n".join( |
800 |
| - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" |
801 |
| - for k in all_components |
802 |
| - ) |
| 847 | + all_components = sorted(expected_component_names | loaded_components) |
803 | 848 |
|
804 | 849 | # Auxiliaries section
|
805 | 850 | auxiliaries_str = " Auxiliaries:\n" + "\n".join(
|
806 | 851 | f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
|
807 | 852 | )
|
| 853 | + main_components = [] |
| 854 | + for k in all_components: |
| 855 | + # Get component spec if available |
| 856 | + component_spec = next((comp for comp in expected_components if comp.name == k), None) |
| 857 | + |
| 858 | + if k in loaded_components: |
| 859 | + component_type = type(self.components[k]).__name__ |
| 860 | + component_str = f" - {k}={component_type}" |
| 861 | + |
| 862 | + # Add expected type info if available |
| 863 | + if component_spec and component_spec.class_name: |
| 864 | + expected_type = component_spec.class_name |
| 865 | + if isinstance(expected_type, (list, tuple)): |
| 866 | + expected_type = expected_type[1] # Get class name from [module, class_name] |
| 867 | + if expected_type != component_type: |
| 868 | + component_str += f" (expected: {expected_type})" |
| 869 | + else: |
| 870 | + # Component not loaded but expected |
| 871 | + if component_spec: |
| 872 | + expected_type = component_spec.class_name |
| 873 | + if isinstance(expected_type, (list, tuple)): |
| 874 | + expected_type = expected_type[1] # Get class name from [module, class_name] |
| 875 | + component_str = f" - {k} (expected: {expected_type})" |
| 876 | + |
| 877 | + # Add repo info if available |
| 878 | + if component_spec.default_repo: |
| 879 | + repo_info = component_spec.default_repo |
| 880 | + if component_spec.subfolder: |
| 881 | + repo_info += f", subfolder={component_spec.subfolder}" |
| 882 | + component_str += f" [{repo_info}]" |
| 883 | + else: |
| 884 | + component_str = f" - {k}" |
| 885 | + |
| 886 | + |
| 887 | + main_components.append(component_str) |
| 888 | + |
| 889 | + components = "Components:\n" + "\n".join(main_components) |
808 | 890 |
|
809 | 891 | # Configs section
|
810 | 892 | expected_configs = set(getattr(self, "expected_configs", []))
|
@@ -1188,19 +1270,54 @@ def __repr__(self):
|
1188 | 1270 | desc = '\n'.join(desc) + '\n'
|
1189 | 1271 |
|
1190 | 1272 | # Components section
|
1191 |
| - expected_components = set(getattr(self, "expected_components", [])) |
| 1273 | + expected_components = getattr(self, "expected_components", []) |
| 1274 | + expected_component_names = {comp.name for comp in expected_components} if expected_components else set() |
1192 | 1275 | loaded_components = set(self.components.keys())
|
1193 |
| - all_components = sorted(expected_components | loaded_components) |
1194 |
| - components_str = " Components:\n" + "\n".join( |
1195 |
| - f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}" |
1196 |
| - for k in all_components |
1197 |
| - ) |
| 1276 | + all_components = sorted(expected_component_names | loaded_components) |
1198 | 1277 |
|
1199 | 1278 | # Auxiliaries section
|
1200 | 1279 | auxiliaries_str = " Auxiliaries:\n" + "\n".join(
|
1201 | 1280 | f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
|
1202 | 1281 | )
|
1203 | 1282 |
|
| 1283 | + main_components = [] |
| 1284 | + for k in all_components: |
| 1285 | + # Get component spec if available |
| 1286 | + component_spec = next((comp for comp in expected_components if comp.name == k), None) |
| 1287 | + |
| 1288 | + if k in loaded_components: |
| 1289 | + component_type = type(self.components[k]).__name__ |
| 1290 | + component_str = f" - {k}={component_type}" |
| 1291 | + |
| 1292 | + # Add expected type info if available |
| 1293 | + if component_spec and component_spec.class_name: |
| 1294 | + expected_type = component_spec.class_name |
| 1295 | + if isinstance(expected_type, (list, tuple)): |
| 1296 | + expected_type = expected_type[1] # Get class name from [module, class_name] |
| 1297 | + if expected_type != component_type: |
| 1298 | + component_str += f" (expected: {expected_type})" |
| 1299 | + else: |
| 1300 | + # Component not loaded but expected |
| 1301 | + if component_spec: |
| 1302 | + expected_type = component_spec.class_name |
| 1303 | + if isinstance(expected_type, (list, tuple)): |
| 1304 | + expected_type = expected_type[1] # Get class name from [module, class_name] |
| 1305 | + component_str = f" - {k} (expected: {expected_type})" |
| 1306 | + |
| 1307 | + # Add repo info if available |
| 1308 | + if component_spec.default_repo: |
| 1309 | + repo_info = component_spec.default_repo |
| 1310 | + if component_spec.subfolder: |
| 1311 | + repo_info += f", subfolder={component_spec.subfolder}" |
| 1312 | + component_str += f" [{repo_info}]" |
| 1313 | + else: |
| 1314 | + component_str = f" - {k}" |
| 1315 | + |
| 1316 | + |
| 1317 | + main_components.append(component_str) |
| 1318 | + |
| 1319 | + components = "Components:\n" + "\n".join(main_components) |
| 1320 | + |
1204 | 1321 | # Configs section
|
1205 | 1322 | expected_configs = set(getattr(self, "expected_configs", []))
|
1206 | 1323 | loaded_configs = set(self.configs.keys())
|
@@ -1558,7 +1675,7 @@ def __repr__(self):
|
1558 | 1675 |
|
1559 | 1676 | return output
|
1560 | 1677 |
|
1561 |
| - # YiYi TO-DO: try to unify the to method with the one in DiffusionPipeline |
| 1678 | + # YiYi TODO: try to unify the to method with the one in DiffusionPipeline |
1562 | 1679 | # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
|
1563 | 1680 | def to(self, *args, **kwargs):
|
1564 | 1681 | r"""
|
|
0 commit comments