diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py index 203be1d4c874..7a7f1bcbd31a 100644 --- a/examples/community/ip_adapter_face_id.py +++ b/examples/community/ip_adapter_face_id.py @@ -282,10 +282,7 @@ def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_ revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name, diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index f4c48f254c44..63b4f226434d 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -159,10 +159,7 @@ def load_ip_adapter( " `low_cpu_mem_usage=False`." ) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dicts = [] for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( pretrained_model_name_or_path_or_dict, weight_name, subfolder @@ -465,10 +462,7 @@ def load_ip_adapter( " `low_cpu_mem_usage=False`." ) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dicts = [] for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( pretrained_model_name_or_path_or_dict, weight_name, subfolder @@ -750,10 +744,7 @@ def load_ip_adapter( " `low_cpu_mem_usage=False`." ) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} if not isinstance(pretrained_model_name_or_path_or_dict, dict): model_file = _get_model_file( diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 32d7c773d2b0..b20b56340ea4 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -14,6 +14,7 @@ import copy import inspect +import json import os from pathlib import Path from typing import Callable, Dict, List, Optional, Union @@ -45,6 +46,7 @@ set_adapter_layers, set_weights_and_activate_adapters, ) +from ..utils.state_dict_utils import _load_sft_state_dict_metadata if is_transformers_available(): @@ -62,6 +64,7 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" +LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): @@ -206,6 +209,7 @@ def _fetch_state_dict( subfolder, user_agent, allow_pickle, + metadata=None, ): model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): @@ -236,11 +240,14 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") + metadata = _load_sft_state_dict_metadata(model_file) + except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e # try loading non-safetensors weights model_file = None + metadata = None pass if model_file is None: @@ -261,10 +268,11 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = load_state_dict(model_file) + metadata = None else: state_dict = pretrained_model_name_or_path_or_dict - return state_dict + 
return state_dict, metadata def _best_guess_weight_name( @@ -306,6 +314,11 @@ def _best_guess_weight_name( return weight_name +def _pack_dict_with_prefix(state_dict, prefix): + sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()} + return sd_with_prefix + + def _load_lora_into_text_encoder( state_dict, network_alphas, @@ -317,10 +330,14 @@ def _load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + if network_alphas and metadata: + raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.") + peft_kwargs = {} if low_cpu_mem_usage: if not is_peft_version(">=", "0.13.1"): @@ -349,6 +366,8 @@ def _load_lora_into_text_encoder( # Load the layers corresponding to text encoder and make necessary adjustments. if prefix is not None: state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + if metadata is not None: + metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: logger.info(f"Loading {prefix}.") @@ -376,7 +395,10 @@ def _load_lora_into_text_encoder( alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys} - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) + if metadata is not None: + lora_config_kwargs = metadata + else: + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: @@ -398,7 +420,10 @@ def _load_lora_into_text_encoder( if is_peft_version("<=", "0.13.2"): lora_config_kwargs.pop("lora_bias") - lora_config = LoraConfig(**lora_config_kwargs) + try: + lora_config = LoraConfig(**lora_config_kwargs) + except TypeError as e: + raise TypeError("`LoraConfig` class could not be instantiated.") from e # adapter_name if adapter_name is None: @@ -889,8 +914,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, @staticmethod def pack_weights(layers, prefix): layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict + return _pack_dict_with_prefix(layers_weights, prefix) @staticmethod def write_lora_layers( @@ -900,16 +924,32 @@ def write_lora_layers( weight_name: str, save_function: Callable, safe_serialization: bool, + lora_adapter_metadata: Optional[dict] = None, ): if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return + if lora_adapter_metadata and not safe_serialization: + raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.") + if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict): + raise TypeError("`lora_adapter_metadata` must be of type `dict`.") + if save_function is None: if safe_serialization: def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + # Inject framework format. 
+ metadata = {"format": "pt"} + if lora_adapter_metadata: + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps( + lora_adapter_metadata, indent=2, sort_keys=True + ) + + return safetensors.torch.save_file(weights, filename, metadata=metadata) else: save_function = torch.save diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 189a9ceba541..27053623eeec 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -37,6 +37,7 @@ LoraBaseMixin, _fetch_state_dict, _load_lora_into_text_encoder, + _pack_dict_with_prefix, ) from .lora_conversion_utils import ( _convert_bfl_flux_control_lora_to_diffusers, @@ -202,7 +203,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -213,6 +215,7 @@ def load_lora_weights( network_alphas=network_alphas, unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -226,6 +229,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + metadata=metadata, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -282,6 +286,8 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. 
@@ -295,18 +301,16 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -343,7 +347,8 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - return state_dict, network_alphas + out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas) + return out @classmethod def load_lora_into_unet( @@ -355,6 +360,7 @@ def load_lora_into_unet( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -378,6 +384,9 @@ def load_lora_into_unet( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -396,6 +405,7 @@ def load_lora_into_unet( prefix=cls.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -413,6 +423,7 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -440,6 +451,9 @@ def load_lora_into_text_encoder( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -449,6 +463,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -464,6 +479,8 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + unet_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -486,8 +503,13 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + unet_lora_adapter_metadata: + LoRA adapter metadata associated with the unet to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. 
""" state_dict = {} + lora_adapter_metadata = {} if not (unet_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") @@ -498,6 +520,14 @@ def save_lora_weights( if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + if unet_lora_adapter_metadata: + lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -506,6 +536,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -641,7 +672,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict( + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, **kwargs, @@ -656,6 +688,7 @@ def load_lora_weights( network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -667,6 +700,7 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -678,6 +712,7 @@ def load_lora_weights( prefix=f"{self.text_encoder_name}_2", lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -736,6 +771,8 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. 
@@ -749,18 +786,16 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -797,7 +832,8 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - return state_dict, network_alphas + out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas) + return out @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet @@ -810,6 +846,7 @@ def load_lora_into_unet( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -833,6 +870,9 @@ def load_lora_into_unet( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -851,6 +891,7 @@ def load_lora_into_unet( prefix=cls.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -869,6 +910,7 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -896,6 +938,9 @@ def load_lora_into_text_encoder( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -905,6 +950,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -921,6 +967,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + unet_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, + text_encoder_2_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -946,8 +995,15 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + unet_lora_adapter_metadata: + LoRA adapter metadata associated with the unet to be serialized with the state dict. 
+ text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. + text_encoder_2_lora_adapter_metadata: + LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -963,6 +1019,19 @@ def save_lora_weights( if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + if unet_lora_adapter_metadata is not None: + lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) + + if text_encoder_2_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") + ) + cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, @@ -970,6 +1039,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -1103,6 +1173,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -1116,18 +1188,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1148,7 +1218,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out def load_lora_weights( self, @@ -1197,7 +1268,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -1207,6 +1279,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1218,6 +1291,7 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1229,6 +1303,7 @@ def load_lora_weights( prefix=f"{self.text_encoder_name}_2", lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1236,7 +1311,14 @@ def load_lora_weights( @classmethod def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1256,6 +1338,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1268,6 +1353,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1286,6 +1372,7 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1313,6 +1400,9 @@ def load_lora_into_text_encoder( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -1322,6 +1412,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1339,6 +1430,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, + text_encoder_2_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -1364,8 +1458,15 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
+ transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. + text_encoder_2_lora_adapter_metadata: + LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -1381,6 +1482,21 @@ def save_lora_weights( if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) + + if text_encoder_2_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") + ) + cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, @@ -1388,6 +1504,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer @@ -1519,6 +1636,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -1532,18 +1651,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1564,7 +1681,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -1609,7 +1727,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -1619,6 +1738,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1627,7 +1747,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1647,6 +1774,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1659,6 +1789,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1674,9 +1805,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -1693,14 +1825,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -1710,6 +1849,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -1843,7 +1983,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
- + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -1856,18 +1997,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1921,8 +2060,13 @@ def lora_state_dict( f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." ) - if return_alphas: - return state_dict, network_alphas + if return_alphas or return_lora_metadata: + outputs = [state_dict] + if return_alphas: + outputs.append(network_alphas) + if return_lora_metadata: + outputs.append(metadata) + return tuple(outputs) else: return state_dict @@ -1973,7 +2117,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict( + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) @@ -2024,6 +2169,7 @@ def load_lora_weights( network_alphas=network_alphas, transformer=transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2043,6 +2189,7 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2055,6 +2202,7 @@ def load_lora_into_transformer( network_alphas, transformer, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -2081,6 +2229,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2093,6 +2244,7 @@ def load_lora_into_transformer( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2165,6 +2317,7 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2192,6 +2345,9 @@ def load_lora_into_text_encoder( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. 
When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -2201,6 +2357,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2217,6 +2374,8 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -2239,8 +2398,13 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") @@ -2251,6 +2415,16 @@ def save_lora_weights( if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -2259,6 +2433,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -2626,6 +2801,7 @@ def load_lora_into_transformer( network_alphas, transformer, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -2652,6 +2828,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2664,6 +2843,7 @@ def load_lora_into_transformer( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2682,6 +2862,7 @@ def load_lora_into_text_encoder( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2709,6 +2890,9 @@ def load_lora_into_text_encoder( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" _load_lora_into_text_encoder( state_dict=state_dict, @@ -2718,6 +2902,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2837,6 +3022,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -2850,18 +3037,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -2882,7 +3067,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out def load_lora_weights( self, @@ -2926,7 +3112,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -2936,6 +3123,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2944,7 +3132,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -2964,6 +3159,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -2976,6 +3174,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2991,9 +3190,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -3010,14 +3210,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -3027,6 +3234,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -3153,6 +3361,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -3166,18 +3376,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3198,7 +3406,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -3243,7 +3452,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3253,6 +3463,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3261,7 +3472,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3281,6 +3499,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3293,6 +3514,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3308,9 +3530,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -3327,14 +3550,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -3344,6 +3574,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -3471,7 +3702,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
- + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -3484,18 +3716,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3520,7 +3750,8 @@ def lora_state_dict( if is_non_diffusers_format: state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict) - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -3565,7 +3796,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3575,6 +3807,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3583,7 +3816,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3603,6 +3843,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3615,6 +3858,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3630,9 +3874,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -3649,14 +3894,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -3666,6 +3918,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -3794,6 +4047,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. 
""" # Load the main state dict first which has the LoRA layers for either of @@ -3807,18 +4062,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3839,7 +4092,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -3884,7 +4138,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3894,6 +4149,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3902,7 +4158,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3922,6 +4185,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3934,6 +4200,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3949,9 +4216,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. 
Arguments: save_directory (`str` or `os.PathLike`): @@ -3968,14 +4236,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -3985,6 +4260,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4112,7 +4388,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -4125,18 +4402,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4161,7 +4436,8 @@ def lora_state_dict( if is_original_hunyuan_video: state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -4206,7 +4482,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4216,6 +4493,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4224,7 +4502,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4244,6 +4529,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4256,6 +4544,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4271,9 +4560,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -4290,14 +4580,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -4307,6 +4604,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4434,7 +4732,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
- + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -4447,18 +4746,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4484,7 +4781,8 @@ def lora_state_dict( if non_diffusers: state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -4529,7 +4827,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4539,6 +4838,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4547,7 +4847,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4567,6 +4874,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
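The `metadata` argument documented above is just a dict of `LoraConfig` keyword arguments recovered from the checkpoint. Below is a standalone sketch of what the loader does with it; the concrete values are made up, while the `LoraConfig(**metadata)` pattern and the `TypeError` wrapping mirror this diff.

```python
from peft import LoraConfig

# Hypothetical metadata, as it would be read back from a safetensors header.
metadata = {"r": 8, "lora_alpha": 16, "target_modules": ["to_q", "to_k", "to_v", "to_out.0"]}

try:
    # When metadata is supplied it is used verbatim, instead of deriving the kwargs
    # from the state dict via get_peft_kwargs().
    lora_config = LoraConfig(**metadata)
except TypeError as e:
    raise TypeError("`LoraConfig` class could not be instantiated.") from e
```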
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4579,6 +4889,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4594,9 +4905,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -4613,14 +4925,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -4630,6 +4949,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -4757,7 +5077,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. 
@@ -4770,18 +5091,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4806,7 +5125,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out @classmethod def _maybe_expand_t2v_lora_for_i2v( @@ -4898,7 +5218,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers state_dict = self._maybe_expand_t2v_lora_for_i2v( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, @@ -4912,6 +5233,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4920,7 +5242,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4940,6 +5269,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4952,6 +5284,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4967,9 +5300,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. 
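Since every `save_lora_weights` override in this patch gains the same `transformer_lora_adapter_metadata` argument, one pipeline-level sketch covers them all. The checkpoint id and adapter name below are hypothetical, and an adapter is assumed to already be attached to the transformer; only the new keyword argument itself comes from the diff.

```python
from peft.utils import get_peft_model_state_dict

from diffusers import WanPipeline  # placeholder; any pipeline with the updated mixin behaves the same

pipe = WanPipeline.from_pretrained("some-user/some-wan-checkpoint")  # hypothetical checkpoint
# ... a LoRA adapter named "default" is assumed to be loaded or trained on pipe.transformer ...

WanPipeline.save_lora_weights(
    save_directory="./wan-lora",
    transformer_lora_layers=get_peft_model_state_dict(pipe.transformer, adapter_name="default"),
    # New in this patch: the adapter's LoraConfig, serialized into the safetensors header.
    transformer_lora_adapter_metadata=pipe.transformer.peft_config["default"].to_dict(),
)
```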
Arguments: save_directory (`str` or `os.PathLike`): @@ -4986,14 +5320,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -5003,6 +5344,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -5131,6 +5473,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -5144,18 +5488,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5176,7 +5518,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -5221,7 +5564,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -5231,6 +5575,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5239,7 +5584,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5259,6 +5611,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5271,6 +5626,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5286,9 +5642,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -5305,14 +5662,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -5322,6 +5686,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -5449,7 +5814,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
- + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -5462,18 +5828,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5498,7 +5862,8 @@ def lora_state_dict( if is_non_diffusers_format: state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict) - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -5543,7 +5908,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -5553,6 +5919,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5561,7 +5928,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5581,6 +5955,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5593,6 +5970,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5608,9 +5986,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -5627,14 +6006,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -5644,6 +6030,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 2c28cef11210..0480e93f356f 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import json import os from functools import partial from pathlib import Path @@ -185,6 +186,7 @@ def load_lora_adapter( Note that hotswapping adapters of the text encoder is not yet supported. There are some further limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap + metadata: TODO """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer @@ -202,6 +204,7 @@ def load_lora_adapter( network_alphas = kwargs.pop("network_alphas", None) _pipeline = kwargs.pop("_pipeline", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) + metadata = kwargs.pop("metadata", None) allow_pickle = False if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): @@ -209,12 +212,9 @@ def load_lora_adapter( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -227,12 +227,17 @@ def load_lora_adapter( subfolder=subfolder, user_agent=user_agent, allow_pickle=allow_pickle, + metadata=metadata, ) if network_alphas is not None and prefix is None: raise ValueError("`network_alphas` cannot be None when `prefix` is None.") + if network_alphas and metadata: + raise ValueError("Both `network_alphas` and `metadata` cannot be specified.") if prefix is not None: state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + if metadata is not None: + metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: @@ -267,7 +272,12 @@ def load_lora_adapter( k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) + if metadata is not None: + lora_config_kwargs = metadata + else: + lora_config_kwargs = get_peft_kwargs( + rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict + ) _maybe_raise_error_for_ambiguity(lora_config_kwargs) if "use_dora" in lora_config_kwargs: @@ -290,7 +300,11 @@ def load_lora_adapter( if is_peft_version("<=", "0.13.2"): lora_config_kwargs.pop("lora_bias") - lora_config = LoraConfig(**lora_config_kwargs) + try: + lora_config = LoraConfig(**lora_config_kwargs) + except TypeError as e: + raise TypeError("`LoraConfig` class could not be instantiated.") from e + # adapter_name if adapter_name is None: adapter_name = get_adapter_name(self) @@ -445,17 +459,13 @@ def save_lora_adapter( underlying model has multiple adapters loaded. upcast_before_saving (`bool`, defaults to `False`): Whether to cast the underlying model to `torch.float32` before serialization. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with. 
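The `save_lora_adapter` body that follows embeds the adapter's `LoraConfig` into the safetensors header. Here is a self-contained sketch of that serialization path; the tensors and config values are dummies, while the metadata key and the set-to-list conversion mirror the diff.

```python
import json

import safetensors.torch
import torch

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # as defined in diffusers.loaders.lora_base

# Dummy stand-ins for peft_config[adapter_name].to_dict() and the LoRA state dict.
lora_adapter_metadata = {"r": 4, "lora_alpha": 4, "target_modules": {"to_q", "to_k"}}
weights = {"to_q.lora_A.weight": torch.zeros(4, 8), "to_q.lora_B.weight": torch.zeros(8, 4)}

metadata = {"format": "pt"}
for key, value in lora_adapter_metadata.items():
    if isinstance(value, set):  # sets (e.g. target_modules) are not JSON-serializable
        lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)

# safetensors only accepts str -> str metadata, which is why the config is JSON-encoded above.
safetensors.torch.save_file(weights, "pytorch_lora_weights.safetensors", metadata=metadata)
```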
""" from peft.utils import get_peft_model_state_dict - from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE if adapter_name is None: adapter_name = get_adapter_name(self) @@ -463,6 +473,8 @@ def save_lora_adapter( if adapter_name not in getattr(self, "peft_config", {}): raise ValueError(f"Adapter name {adapter_name} not found in the model.") + lora_adapter_metadata = self.peft_config[adapter_name].to_dict() + lora_layers_to_save = get_peft_model_state_dict( self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name ) @@ -472,7 +484,15 @@ def save_lora_adapter( if safe_serialization: def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + # Inject framework format. + metadata = {"format": "pt"} + if lora_adapter_metadata is not None: + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + + return safetensors.torch.save_file(weights, filename, metadata=metadata) else: save_function = torch.save @@ -485,7 +505,6 @@ def save_function(weights, filename): else: weight_name = LORA_WEIGHT_NAME - # TODO: we could consider saving the `peft_config` as well. save_path = Path(save_directory, weight_name).as_posix() save_function(lora_layers_to_save, save_path) logger.info(f"Model weights saved in {save_path}") diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 1d8aba900c85..d9308c57bfe5 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -155,10 +155,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 15a91040c48c..498f7e566c67 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -16,6 +16,7 @@ """ import enum +import json from .import_utils import is_torch_available from .logging import get_logger @@ -347,3 +348,16 @@ def state_dict_all_zero(state_dict, filter_str=None): state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)} return all(torch.all(param == 0).item() for param in state_dict.values()) + + +def _load_sft_state_dict_metadata(model_file: str): + import safetensors.torch + + from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY + + with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f: + metadata = f.metadata() or {} + + metadata.pop("format", None) + raw = metadata.get(LORA_ADAPTER_METADATA_KEY) + return json.loads(raw) if raw else None diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 5cbe5ff27780..e5da39c1d865 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -133,6 +133,29 @@ def numpy_cosine_similarity_distance(a, b): return distance +def check_if_dicts_are_equal(dict1, dict2): + dict1, dict2 = dict1.copy(), dict2.copy() + + for key, value in dict1.items(): 
+ if isinstance(value, set): + dict1[key] = sorted(value) + for key, value in dict2.items(): + if isinstance(value, set): + dict2[key] = sorted(value) + + for key in dict1: + if key not in dict2: + return False + if dict1[key] != dict2[key]: + return False + + for key in dict2: + if key not in dict1: + return False + + return True + + def print_tensor_test( tensor, limit_to_slices=None, diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index c2498fa68c3d..a1420012d601 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -24,11 +24,7 @@ WanPipeline, WanTransformer3DModel, ) -from diffusers.utils.testing_utils import ( - floats_tensor, - require_peft_backend, - skip_mps, -) +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps sys.path.append(".") diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 7c89f5a47df9..e419c61f6602 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -22,6 +22,7 @@ import numpy as np import pytest import torch +from parameterized import parameterized from diffusers import ( AutoencoderKL, @@ -33,6 +34,7 @@ from diffusers.utils.import_utils import is_peft_available from diffusers.utils.testing_utils import ( CaptureLogger, + check_if_dicts_are_equal, floats_tensor, is_torch_version, require_peft_backend, @@ -71,6 +73,13 @@ def check_if_lora_correctly_set(model) -> bool: return False +def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str): + extracted = { + k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.") + } + check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"]) + + def initialize_dummy_state_dict(state_dict): if not all(v.device.type == "meta" for _, v in state_dict.items()): raise ValueError("`state_dict` has non-meta values.") @@ -118,7 +127,7 @@ class PeftLoraLoaderMixinTests: text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - def get_dummy_components(self, scheduler_cls=None, use_dora=False): + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") if self.has_two_text_encoders and self.has_three_text_encoders: @@ -126,6 +135,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls rank = 4 + lora_alpha = rank if lora_alpha is None else lora_alpha torch.manual_seed(0) if self.unet_kwargs is not None: @@ -161,7 +171,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): text_lora_config = LoraConfig( r=rank, - lora_alpha=rank, + lora_alpha=lora_alpha, target_modules=self.text_encoder_target_modules, init_lora_weights=False, use_dora=use_dora, @@ -169,7 +179,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): denoiser_lora_config = LoraConfig( r=rank, - lora_alpha=rank, + lora_alpha=lora_alpha, target_modules=self.denoiser_target_modules, init_lora_weights=False, use_dora=use_dora, @@ -246,6 +256,13 @@ def _get_lora_state_dicts(self, modules_to_save): state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) return state_dicts + def _get_lora_adapter_metadata(self, modules_to_save): + metadatas = {} + for 
module_name, module in modules_to_save.items(): + if module is not None: + metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict() + return metadatas + def _get_modules_to_save(self, pipe, has_denoiser=False): modules_to_save = {} lora_loadable_modules = self.pipeline_class._lora_loadable_modules @@ -2214,6 +2231,86 @@ def check_module(denoiser): _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe(**inputs, generator=torch.manual_seed(0))[0] + @parameterized.expand([4, 8, 16]) + def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = self.get_dummy_components( + scheduler_cls, lora_alpha=lora_alpha + ) + pipe = self.pipeline_class(**components) + + pipe, _ = self.check_if_adapters_added_correctly( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + + with tempfile.TemporaryDirectory() as tmpdir: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + + out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True) + if len(out) == 3: + _, _, parsed_metadata = out + elif len(out) == 2: + _, parsed_metadata = out + + denoiser_key = ( + f"{self.pipeline_class.transformer_name}" + if self.transformer_kwargs is not None + else f"{self.pipeline_class.unet_name}" + ) + self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key + ) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + text_encoder_key = self.pipeline_class.text_encoder_name + self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key + ) + + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_key = "text_encoder_2" + self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key + ) + + @parameterized.expand([4, 8, 16]) + def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = self.get_dummy_components( + scheduler_cls, lora_alpha=lora_alpha + ) + pipe = self.pipeline_class(**components).to(torch_device) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(output_no_lora.shape == self.output_shape) + + pipe, _ = self.check_if_adapters_added_correctly( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdir: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + 
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdir) + + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue( + np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." + ) + def test_inference_load_delete_load_adapters(self): "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." for scheduler_cls in self.scheduler_classes: diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 5087bd0094a5..511fa4bfa9ea 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -30,6 +30,7 @@ import numpy as np import requests_mock +import safetensors.torch import torch import torch.nn as nn from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size @@ -62,6 +63,7 @@ backend_max_memory_allocated, backend_reset_peak_memory_stats, backend_synchronize, + check_if_dicts_are_equal, get_python_version, is_torch_compile, numpy_cosine_similarity_distance, @@ -1057,11 +1059,10 @@ def test_deprecated_kwargs(self): " from `_deprecated_kwargs = []`" ) - @parameterized.expand([True, False]) + @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)]) @torch.no_grad() @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_lora_save_load_adapter(self, use_dora=False): - import safetensors + def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False): from peft import LoraConfig from peft.utils import get_peft_model_state_dict @@ -1077,8 +1078,8 @@ def test_lora_save_load_adapter(self, use_dora=False): output_no_lora = model(**inputs_dict, return_dict=False)[0] denoiser_lora_config = LoraConfig( - r=4, - lora_alpha=4, + r=rank, + lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False, use_dora=use_dora, @@ -1145,6 +1146,90 @@ def test_lora_wrong_adapter_name_raises_error(self): self.assertTrue(f"Adapter name {wrong_name} not found in the model." 
in str(err_context.exception)) + @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)]) + @torch.no_grad() + @unittest.skipIf(not is_peft_available(), "Only with PEFT") + def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora): + from peft import LoraConfig + + from diffusers.loaders.peft import PeftAdapterMixin + + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + if not issubclass(model.__class__, PeftAdapterMixin): + return + + denoiser_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=use_dora, + ) + model.add_adapter(denoiser_lora_config) + metadata = model.peft_config["default"].to_dict() + self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_lora_adapter(tmpdir) + model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(model_file)) + + model.unload_lora() + self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + + model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) + parsed_metadata = model.peft_config["default_0"].to_dict() + check_if_dicts_are_equal(metadata, parsed_metadata) + + @torch.no_grad() + @unittest.skipIf(not is_peft_available(), "Only with PEFT") + def test_lora_adapter_wrong_metadata_raises_error(self): + from peft import LoraConfig + + from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + from diffusers.loaders.peft import PeftAdapterMixin + + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + if not issubclass(model.__class__, PeftAdapterMixin): + return + + denoiser_lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + model.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_lora_adapter(tmpdir) + model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(model_file)) + + # Perturb the metadata in the state dict. + loaded_state_dict = safetensors.torch.load_file(model_file) + metadata = {"format": "pt"} + lora_adapter_metadata = denoiser_lora_config.to_dict() + lora_adapter_metadata.update({"foo": 1, "bar": 2}) + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata) + + model.unload_lora() + self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + + with self.assertRaises(TypeError) as err_context: + model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) + self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception)) + @require_torch_accelerator def test_cpu_offload(self): config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
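Finally, for readers who want to inspect on disk what the new tests assert, this is how the serialized metadata can be read back, following the `_load_sft_state_dict_metadata` helper added to `state_dict_utils.py`. The file name matches the one produced by the sketch after the `save_lora_adapter` hunk; adjust it to your own checkpoint.

```python
import json

from safetensors import safe_open

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"

# Mirrors _load_sft_state_dict_metadata(): read the safetensors header, drop the
# framework marker, and decode the JSON-encoded LoraConfig kwargs if present.
with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
lora_config_kwargs = json.loads(raw) if raw else None
print(lora_config_kwargs)
```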