diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index c5cb27a35f3c..e48725b01ca2 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -843,11 +843,11 @@ def save_lora_weights(

         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
-                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
             )

         if unet_lora_layers:
-            state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
+            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))

         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
@@ -1210,10 +1210,11 @@ def load_lora_into_text_encoder(
         )

     @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
     def save_lora_weights(
         cls,
         save_directory: Union[str, os.PathLike],
-        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         is_main_process: bool = True,
@@ -1262,7 +1263,6 @@ def save_lora_weights(
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

-        # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
@@ -1272,6 +1272,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
         components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1315,6 +1316,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
     def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
         r"""
         Reverses the effect of
@@ -1328,7 +1330,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t

         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
-            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
             unfuse_text_encoder (`bool`, defaults to `True`):
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
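Aside on the first hunk: `pack_weights` namespaces each component's LoRA weights under a string prefix in the serialized state dict, so switching the literal `"unet"` to `cls.unet_name` lets subclasses that rename the component (e.g. to `transformer`) produce correctly prefixed keys. A minimal sketch of that behavior, assuming `pack_weights` accepts either a module or a plain state dict (the exact diffusers implementation may differ):

```python
import torch

def pack_weights(layers, prefix):
    # Accept either an nn.Module or a plain state dict and prefix every key
    # with the component name, e.g. "unet.<key>" or "transformer.<key>".
    weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{key}": param for key, param in weights.items()}

# A subclass that sets unet_name = "transformer" would still serialize "unet.*"
# keys under the old hard-coded prefix; cls.unet_name keeps the two in sync.
layer = torch.nn.Linear(4, 4)
print(list(pack_weights(layer, "transformer")))  # ['transformer.weight', 'transformer.bias']
```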
@@ -2833,6 +2835,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -2876,6 +2879,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -3136,6 +3140,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -3179,6 +3184,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -3439,6 +3445,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -3482,6 +3489,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
@@ -3745,6 +3753,7 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -3788,6 +3797,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
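For reference, the `fuse_lora`/`unfuse_lora` pair that these `# Copied from` annotations cover is typically driven from the pipeline like this (a sketch; the checkpoint IDs are placeholders, not real repos):

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint IDs; substitute a real base model and LoRA.
pipe = DiffusionPipeline.from_pretrained("base/model-id", torch_dtype=torch.float16)
pipe.load_lora_weights("user/lora-id")

# Merge the scaled LoRA deltas into the listed components to avoid the
# per-step adapter overhead during inference ...
pipe.fuse_lora(components=["transformer"], lora_scale=0.8)
image = pipe("a prompt").images[0]
# ... then restore the original, unfused weights afterwards.
pipe.unfuse_lora(components=["transformer"])
```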