@@ -843,11 +843,11 @@ def save_lora_weights(

         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
-                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
             )

         if unet_lora_layers:
-            state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
+            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))

         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
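For context, `pack_weights` namespaces every LoRA key under the component it belongs to, so passing `cls.unet_name` instead of the hard-coded "unet" string keeps the saved keys aligned with whatever name the mixin declares for its UNet. A minimal sketch of the behavior this call relies on, not the library's exact implementation:

import torch

def pack_weights(layers, prefix):
    # Accept either a module (take its state_dict) or an already-flattened dict of tensors,
    # then namespace every key under the component prefix, e.g. "unet.<key>" or "text_encoder.<key>".
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}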
@@ -1210,10 +1210,11 @@ def load_lora_into_text_encoder(
         )

     @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
     def save_lora_weights(
         cls,
         save_directory: Union[str, os.PathLike],
-        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         is_main_process: bool = True,
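The widened annotation matters in practice: training scripts usually hand save_lora_weights a plain dict of LoRA tensors rather than an nn.Module, which the old Dict[str, torch.nn.Module] hint did not cover, and it now matches the text encoder arguments. A hedged usage sketch, with pipe and transformer as placeholder names not taken from this diff:

from peft.utils import get_peft_model_state_dict

# Either form should satisfy the new annotation: a module, or a flat dict of LoRA tensors.
transformer_lora_layers = get_peft_model_state_dict(transformer)  # Dict[str, torch.Tensor]

pipe.save_lora_weights(
    save_directory="./my-lora",  # placeholder path
    transformer_lora_layers=transformer_lora_layers,
)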
@@ -1262,7 +1263,6 @@ def save_lora_weights(
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

-        # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
@@ -1272,6 +1272,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
         components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1315,6 +1316,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
     def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
         r"""
         Reverses the effect of
@@ -1328,7 +1330,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t

         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
-            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
             unfuse_text_encoder (`bool`, defaults to `True`):
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
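For orientation, fuse_lora and unfuse_lora are intended as a round trip: fusing folds the loaded LoRA deltas into the base weights so inference skips the adapter forward pass, and unfusing restores the originals. A hedged sketch, assuming pipe is a placeholder pipeline that uses this mixin and already has LoRA weights loaded:

prompt = "a photo of a cat"  # placeholder prompt
pipe.fuse_lora(components=["transformer", "text_encoder"], lora_scale=0.8)
image = pipe(prompt).images[0]  # inference runs against the fused weights
pipe.unfuse_lora(components=["transformer", "text_encoder"])  # restore the original base weights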
@@ -2833,6 +2835,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -2876,6 +2879,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -3136,6 +3140,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -3179,6 +3184,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -3439,6 +3445,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -3482,6 +3489,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -3745,6 +3753,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -3788,6 +3797,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of