Skip to content

Commit 1fddee2

Browse files
authored
[LoRA] Improve copied from comments in the LoRA loader classes (#10995)
* more sanity of mind with copied from ... * better * better
1 parent b38450d commit 1fddee2

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -843,11 +843,11 @@ def save_lora_weights(
843843

844844
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
845845
raise ValueError(
846-
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
846+
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
847847
)
848848

849849
if unet_lora_layers:
850-
state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
850+
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
851851

852852
if text_encoder_lora_layers:
853853
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
@@ -1210,10 +1210,11 @@ def load_lora_into_text_encoder(
12101210
)
12111211

12121212
@classmethod
1213+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
12131214
def save_lora_weights(
12141215
cls,
12151216
save_directory: Union[str, os.PathLike],
1216-
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
1217+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
12171218
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
12181219
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
12191220
is_main_process: bool = True,
@@ -1262,7 +1263,6 @@ def save_lora_weights(
12621263
if text_encoder_2_lora_layers:
12631264
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
12641265

1265-
# Save the model
12661266
cls.write_lora_layers(
12671267
state_dict=state_dict,
12681268
save_directory=save_directory,
@@ -1272,6 +1272,7 @@ def save_lora_weights(
12721272
safe_serialization=safe_serialization,
12731273
)
12741274

1275+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
12751276
def fuse_lora(
12761277
self,
12771278
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1315,6 +1316,7 @@ def fuse_lora(
13151316
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
13161317
)
13171318

1319+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
13181320
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
13191321
r"""
13201322
Reverses the effect of
@@ -1328,7 +1330,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
13281330
13291331
Args:
13301332
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
1331-
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
1333+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
13321334
unfuse_text_encoder (`bool`, defaults to `True`):
13331335
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
13341336
LoRA parameters then it won't have any effect.
@@ -2833,6 +2835,7 @@ def save_lora_weights(
28332835
safe_serialization=safe_serialization,
28342836
)
28352837

2838+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
28362839
def fuse_lora(
28372840
self,
28382841
components: List[str] = ["transformer"],
@@ -2876,6 +2879,7 @@ def fuse_lora(
28762879
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
28772880
)
28782881

2882+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
28792883
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
28802884
r"""
28812885
Reverses the effect of
@@ -3136,6 +3140,7 @@ def save_lora_weights(
31363140
safe_serialization=safe_serialization,
31373141
)
31383142

3143+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
31393144
def fuse_lora(
31403145
self,
31413146
components: List[str] = ["transformer"],
@@ -3179,6 +3184,7 @@ def fuse_lora(
31793184
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
31803185
)
31813186

3187+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
31823188
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
31833189
r"""
31843190
Reverses the effect of
@@ -3439,6 +3445,7 @@ def save_lora_weights(
34393445
safe_serialization=safe_serialization,
34403446
)
34413447

3448+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
34423449
def fuse_lora(
34433450
self,
34443451
components: List[str] = ["transformer"],
@@ -3482,6 +3489,7 @@ def fuse_lora(
34823489
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
34833490
)
34843491

3492+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
34853493
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
34863494
r"""
34873495
Reverses the effect of
@@ -3745,6 +3753,7 @@ def save_lora_weights(
37453753
safe_serialization=safe_serialization,
37463754
)
37473755

3756+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
37483757
def fuse_lora(
37493758
self,
37503759
components: List[str] = ["transformer"],
@@ -3788,6 +3797,7 @@ def fuse_lora(
37883797
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
37893798
)
37903799

3800+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
37913801
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
37923802
r"""
37933803
Reverses the effect of

0 commit comments

Comments
 (0)