Skip to content

Commit bf99ab2

Browse files
committed
up
1 parent ee84283 commit bf99ab2

File tree

2 files changed

+593
-1915
lines changed

2 files changed

+593
-1915
lines changed

src/diffusers/pipelines/modular_pipeline.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -138,22 +138,40 @@ def format_value(v):
138138
return f"BlockState(\n{attributes}\n)"
139139

140140

141+
@dataclass
142+
class ComponentSpec:
143+
"""Specification for a pipeline component."""
144+
name: str
145+
type_hint: Type
146+
description: Optional[str] = None
147+
default: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor
148+
default_class_name: Union[str, List[str], Tuple[str, str]] = None # Either "class_name" or ["module", "class_name"]
149+
default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"]
150+
151+
@dataclass
152+
class ConfigSpec:
153+
"""Specification for a pipeline configuration parameter."""
154+
name: str
155+
default: Any
156+
description: Optional[str] = None
157+
158+
141159
@dataclass
142160
class InputParam:
143161
name: str
162+
type_hint: Any = None
144163
default: Any = None
145164
required: bool = False
146165
description: str = ""
147-
type_hint: Any = Any
148166

149167
def __repr__(self):
150168
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
151169

152170
@dataclass
153171
class OutputParam:
154172
name: str
173+
type_hint: Any
155174
description: str = ""
156-
type_hint: Any = Any
157175

158176
def __repr__(self):
159177
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
@@ -338,49 +356,40 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description=""):
338356
return output
339357

340358

341-
@dataclass
342-
class ComponentSpec:
343-
"""Specification for a pipeline component."""
344-
name: str
345-
type_hint: Optional[Type] = None
346-
description: Optional[str] = None
347-
default: Any = None # you can create a default component if it is a stateless class like scheduler, guider or image processor
348-
default_class_name: Union[str, List[str], Tuple[str, str]] # Either "class_name" or ["module", "class_name"]
349-
default_repo: Optional[Union[str, List[str]]] = None # either "repo" or ["repo", "subfolder"]
350-
351-
@dataclass
352-
class ConfigSpec:
353-
"""Specification for a pipeline configuration parameter."""
354-
name: str
355-
default: Any
356-
description: Optional[str] = None
357-
type_hint: Optional[Type] = None
358359

359360
class PipelineBlock:
360-
361-
component_specs: List[ComponentSpec] = []
362-
config_specs: List[ConfigSpec] = []
361+
363362
model_name = None
364363

365364
@property
366365
def description(self) -> str:
367366
"""Description of the block. Must be implemented by subclasses."""
368367
raise NotImplementedError("description method must be implemented in subclasses")
368+
369+
@property
370+
def components(self) -> List[ComponentSpec]:
371+
return []
369372

373+
@property
374+
def configs(self) -> List[ConfigSpec]:
375+
return []
376+
377+
378+
# YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable
370379
@property
371380
def inputs(self) -> List[InputParam]:
372381
"""List of input parameters. Must be implemented by subclasses."""
373-
raise NotImplementedError("inputs method must be implemented in subclasses")
382+
return []
374383

375384
@property
376385
def intermediates_inputs(self) -> List[InputParam]:
377386
"""List of intermediate input parameters. Must be implemented by subclasses."""
378-
raise NotImplementedError("intermediates_inputs method must be implemented in subclasses")
387+
return []
379388

380389
@property
381390
def intermediates_outputs(self) -> List[OutputParam]:
382391
"""List of intermediate output parameters. Must be implemented by subclasses."""
383-
raise NotImplementedError("intermediates_outputs method must be implemented in subclasses")
392+
return []
384393

385394
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
386395
@property
@@ -403,10 +412,6 @@ def required_intermediates_inputs(self) -> List[str]:
403412
input_names.append(input_param.name)
404413
return input_names
405414

406-
def __init__(self):
407-
self.components: Dict[str, Any] = {}
408-
self.auxiliaries: Dict[str, Any] = {}
409-
self.configs: Dict[str, Any] = {}
410415

411416
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
412417
raise NotImplementedError("__call__ method must be implemented in subclasses")

0 commit comments

Comments
 (0)