Skip to content

Commit 39e1f7e

Browse files
asomozasayakpaul
andauthored
[Kolors] Add PAG (#8934)
* txt2img pag added * autopipe added, fixed case * style * apply suggestions * added fast tests, added todo tests * revert dummy objects for kolors * fix pag dummies * fix test imports * update pag tests * add kolor pag to docs --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent e1b603d commit 39e1f7e

File tree

12 files changed

+1589
-19
lines changed

12 files changed

+1589
-19
lines changed

docs/source/en/api/pipelines/pag.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
4343
- all
4444
- __call__
4545

46+
## KolorsPAGPipeline
47+
[[autodoc]] KolorsPAGPipeline
48+
- all
49+
- __call__
50+
4651
## StableDiffusionPAGPipeline
4752
[[autodoc]] StableDiffusionPAGPipeline
4853
- all

src/diffusers/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,6 @@
280280
"KandinskyV22Pipeline",
281281
"KandinskyV22PriorEmb2EmbPipeline",
282282
"KandinskyV22PriorPipeline",
283-
"KolorsImg2ImgPipeline",
284-
"KolorsPipeline",
285283
"LatentConsistencyModelImg2ImgPipeline",
286284
"LatentConsistencyModelPipeline",
287285
"LattePipeline",
@@ -397,7 +395,7 @@
397395
]
398396

399397
else:
400-
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPipeline"])
398+
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
401399

402400
try:
403401
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
@@ -820,7 +818,7 @@
820818
except OptionalDependencyNotAvailable:
821819
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
822820
else:
823-
from .pipelines import KolorsImg2ImgPipeline, KolorsPipeline
821+
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
824822
try:
825823
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
826824
raise OptionalDependencyNotAvailable()

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@
146146
_import_structure["pag"].extend(
147147
[
148148
"AnimateDiffPAGPipeline",
149+
"KolorsPAGPipeline",
149150
"HunyuanDiTPAGPipeline",
150151
"StableDiffusion3PAGPipeline",
151152
"StableDiffusionPAGPipeline",
@@ -540,6 +541,7 @@
540541
from .pag import (
541542
AnimateDiffPAGPipeline,
542543
HunyuanDiTPAGPipeline,
544+
KolorsPAGPipeline,
543545
PixArtSigmaPAGPipeline,
544546
StableDiffusion3PAGPipeline,
545547
StableDiffusionControlNetPAGPipeline,

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,10 @@
162162

163163
if is_sentencepiece_available():
164164
from .kolors import KolorsPipeline
165+
from .pag import KolorsPAGPipeline
165166

166167
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
168+
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors-pag"] = KolorsPAGPipeline
167169
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
168170

169171
SUPPORTED_TASKS_MAPPINGS = [

src/diffusers/pipelines/kolors/tokenizer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,18 @@ def get_command(self, token):
143143
def unk_token(self) -> str:
144144
return "<unk>"
145145

146+
@unk_token.setter
147+
def unk_token(self, value: str):
148+
self._unk_token = value
149+
146150
@property
147151
def pad_token(self) -> str:
148152
return "<unk>"
149153

154+
@pad_token.setter
155+
def pad_token(self, value: str):
156+
self._pad_token = value
157+
150158
@property
151159
def pad_token_id(self):
152160
return self.get_command("<pad>")
@@ -155,6 +163,10 @@ def pad_token_id(self):
155163
def eos_token(self) -> str:
156164
return "</s>"
157165

166+
@eos_token.setter
167+
def eos_token(self, value: str):
168+
self._eos_token = value
169+
158170
@property
159171
def eos_token_id(self):
160172
return self.get_command("<eos>")

src/diffusers/pipelines/pag/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
2626
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
2727
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
28+
_import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"]
2829
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
2930
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
3031
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
@@ -44,6 +45,7 @@
4445
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
4546
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
4647
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
48+
from .pipeline_pag_kolors import KolorsPAGPipeline
4749
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
4850
from .pipeline_pag_sd import StableDiffusionPAGPipeline
4951
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline

0 commit comments

Comments
 (0)