From 5139de1165bc833253dd09099fa9b60c080aba81 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Apr 2025 11:44:31 +0530 Subject: [PATCH 01/53] feat: parse metadata from lora state dicts. --- src/diffusers/loaders/lora_base.py | 21 +++++++++++++- src/diffusers/loaders/lora_pipeline.py | 23 +++++++++++++++- src/diffusers/loaders/peft.py | 38 +++++++++++++++++++++++--- src/diffusers/utils/peft_utils.py | 7 ++++- tests/lora/test_lora_layers_wan.py | 33 ++++++++++++++++++---- 5 files changed, 110 insertions(+), 12 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 280a9fa6e73f..ae590245f3c8 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -206,6 +206,7 @@ def _fetch_state_dict( subfolder, user_agent, allow_pickle, + load_with_metadata=False, ): model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): @@ -223,6 +224,9 @@ def _fetch_state_dict( file_extension=".safetensors", local_files_only=local_files_only, ) + if load_with_metadata and not weight_name.endswith(".safetensors"): + raise ValueError("`load_with_metadata` cannot be set to True when not using safetensors.") + model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, @@ -236,6 +240,12 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") + if load_with_metadata: + with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f: + if hasattr(f, "metadata") and f.metadata() is not None: + state_dict["_metadata"] = f.metadata() + else: + raise ValueError("Metadata couldn't be parsed from the safetensors file.") except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e @@ -882,16 +892,25 @@ def write_lora_layers( weight_name: str, save_function: Callable, safe_serialization: bool, + lora_adapter_metadata: dict = None, ): if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return + if lora_adapter_metadata is not None and not safe_serialization: + raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.") + if not isinstance(lora_adapter_metadata, dict): + raise ValueError("`lora_adapter_metadata` must be of type `dict`.") + if save_function is None: if safe_serialization: def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + metadata = {"format": "pt"} + if lora_adapter_metadata is not None: + metadata.update(lora_adapter_metadata) + return safetensors.torch.save_file(weights, filename, metadata=metadata) else: save_function = torch.save diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 2e241bc9ffad..00596b1b0139 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4734,6 +4734,7 @@ def lora_state_dict( - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + load_with_metadata: TODO cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. 
@@ -4768,6 +4769,7 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + load_with_metadata = kwargs.pop("load_with_metadata", False) allow_pickle = False if use_safetensors is None: @@ -4792,6 +4794,7 @@ def lora_state_dict( subfolder=subfolder, user_agent=user_agent, allow_pickle=allow_pickle, + load_with_metadata=load_with_metadata, ) if any(k.startswith("diffusion_model.") for k in state_dict): state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) @@ -4859,6 +4862,7 @@ def load_lora_weights( raise ValueError("PEFT backend is required for this method.") low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + load_with_metdata = kwargs.get("load_with_metdata", False) if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." @@ -4885,12 +4889,20 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + load_with_metdata=load_with_metdata, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + load_with_metadata: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4931,6 +4943,7 @@ def load_lora_into_transformer( Note that hotswapping adapters of the text encoder is not yet supported. There are some further limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap + load_with_metadata: TODO """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4946,6 +4959,7 @@ def load_lora_into_transformer( _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, + load_with_metadata=load_with_metadata, ) @classmethod @@ -4958,6 +4972,7 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -4977,8 +4992,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
+ transformer_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") @@ -4986,6 +5003,9 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -4994,6 +5014,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 9165c46f3c78..208ee4b7a5fe 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -113,7 +113,12 @@ def _optionally_disable_offloading(cls, _pipeline): return _func_optionally_disable_offloading(_pipeline=_pipeline) def load_lora_adapter( - self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs + self, + pretrained_model_name_or_path_or_dict, + prefix="transformer", + hotswap: bool = False, + load_with_metadata: bool = False, + **kwargs, ): r""" Loads a LoRA adapter into the underlying model. @@ -181,6 +186,8 @@ def load_lora_adapter( Note that hotswapping adapters of the text encoder is not yet supported. There are some further limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap + + load_with_metadata: TODO """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer @@ -223,10 +230,14 @@ def load_lora_adapter( subfolder=subfolder, user_agent=user_agent, allow_pickle=allow_pickle, + load_with_metadata=load_with_metadata, ) if network_alphas is not None and prefix is None: raise ValueError("`network_alphas` cannot be None when `prefix` is None.") + if load_with_metadata is not None and not use_safetensors: + raise ValueError("`load_with_metadata` cannot be specified when not using `use_safetensors`.") + if prefix is not None: state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} @@ -261,7 +272,12 @@ def load_lora_adapter( alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) + lora_config_kwargs = get_peft_kwargs( + rank, + network_alpha_dict=network_alphas, + peft_state_dict=state_dict, + load_with_metadata=load_with_metadata, + ) _maybe_raise_error_for_ambiguity(lora_config_kwargs) if "use_dora" in lora_config_kwargs: @@ -284,7 +300,11 @@ def load_lora_adapter( if is_peft_version("<=", "0.13.2"): lora_config_kwargs.pop("lora_bias") - lora_config = LoraConfig(**lora_config_kwargs) + try: + lora_config = LoraConfig(**lora_config_kwargs) + except TypeError as e: + logger.error(f"`LoraConfig` class could not be instantiated with the following trace: {e}.") + # adapter_name if adapter_name is None: adapter_name = get_adapter_name(self) @@ -428,6 +448,7 @@ def save_lora_adapter( upcast_before_saving: bool = False, 
safe_serialization: bool = True, weight_name: Optional[str] = None, + lora_adapter_metadata: Optional[dict] = None, ): """ Save the LoRA parameters corresponding to the underlying model. @@ -446,11 +467,17 @@ def save_lora_adapter( safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with. + lora_adapter_metadata: TODO """ from peft.utils import get_peft_model_state_dict from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + if lora_adapter_metadata is not None and not safe_serialization: + raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.") + if not isinstance(lora_adapter_metadata, dict): + raise ValueError("`lora_adapter_metadata` must be of type `dict`.") + if adapter_name is None: adapter_name = get_adapter_name(self) @@ -466,7 +493,10 @@ def save_lora_adapter( if safe_serialization: def save_function(weights, filename): - return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) + metadata = {"format": "pt"} + if lora_adapter_metadata is not None: + metadata.update(lora_adapter_metadata) + return safetensors.torch.save_file(weights, filename, metadata=metadata) else: save_function = torch.save diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index d1269fbc5f20..28c3ab29773f 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -147,7 +147,12 @@ def unscale_lora_layers(model, weight: Optional[float] = None): module.set_scale(adapter_name, 1.0) -def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): +def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, load_with_metadata=False): + if load_with_metadata: + if "_metadata" not in peft_state_dict: + raise ValueError("Couldn't find '_metadata' key in the `peft_state_dict`.") + return peft_state_dict["_metadata"] + rank_pattern = {} alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index c2498fa68c3d..8d13339e554d 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -13,6 +13,7 @@ # limitations under the License. import sys +import tempfile import unittest import torch @@ -24,11 +25,7 @@ WanPipeline, WanTransformer3DModel, ) -from diffusers.utils.testing_utils import ( - floats_tensor, - require_peft_backend, - skip_mps, -) +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device sys.path.append(".") @@ -141,3 +138,29 @@ def test_simple_inference_with_text_lora_fused(self): @unittest.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_text_lora_save_load(self): pass + + def test_save_load_with_adapter_metadata(self): + # Will write the test in utils.py eventually. 
+ scheduler_cls = self.scheduler_classes[0] + components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(output_no_lora.shape == self.output_shape) + + pipe, _ = self.check_if_adapters_added_correctly( + pipe, text_lora_config=None, denoiser_lora_config=denoiser_lora_config + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, + safe_serialization=False, + lora_adapter_metadata=denoiser_lora_config.to_dict(), + **lora_state_dicts, + ) From d8a305e0eec38c01158aa9acce784cacc6303bdc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Apr 2025 14:43:49 +0530 Subject: [PATCH 02/53] tests --- src/diffusers/loaders/lora_base.py | 19 ++++++--- src/diffusers/loaders/peft.py | 18 ++++++--- src/diffusers/utils/peft_utils.py | 9 ++++- src/diffusers/utils/state_dict_utils.py | 17 ++++++++ tests/lora/test_lora_layers_wan.py | 52 +++++++++++++++++++++---- tests/lora/utils.py | 27 +++++++++++++ 6 files changed, 121 insertions(+), 21 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index ae590245f3c8..a845fab8d97e 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -14,6 +14,7 @@ import copy import inspect +import json import os from pathlib import Path from typing import Callable, Dict, List, Optional, Union @@ -45,6 +46,7 @@ set_adapter_layers, set_weights_and_activate_adapters, ) +from ..utils.state_dict_utils import _maybe_populate_state_dict_with_metadata if is_transformers_available(): @@ -241,11 +243,10 @@ def _fetch_state_dict( ) state_dict = safetensors.torch.load_file(model_file, device="cpu") if load_with_metadata: - with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f: - if hasattr(f, "metadata") and f.metadata() is not None: - state_dict["_metadata"] = f.metadata() - else: - raise ValueError("Metadata couldn't be parsed from the safetensors file.") + state_dict = _maybe_populate_state_dict_with_metadata( + state_dict, model_file, metadata_key="lora_adapter_config" + ) + except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e @@ -907,9 +908,15 @@ def write_lora_layers( if safe_serialization: def save_function(weights, filename): + # We need to be able to serialize the NoneTypes too, otherwise we run into + # 'NoneType' object cannot be converted to 'PyString' metadata = {"format": "pt"} if lora_adapter_metadata is not None: - metadata.update(lora_adapter_metadata) + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata["lora_adapter_config"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + return safetensors.torch.save_file(weights, filename, metadata=metadata) else: diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 208ee4b7a5fe..70425c96b153 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import inspect +import json import os from functools import partial from pathlib import Path @@ -239,8 +240,12 @@ def load_lora_adapter( raise ValueError("`load_with_metadata` cannot be specified when not using `use_safetensors`.") if prefix is not None: + metadata = state_dict.pop("_metadata", None) state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + if metadata is not None: + state_dict["_metadata"] = metadata + if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: raise ValueError( @@ -277,6 +282,7 @@ def load_lora_adapter( network_alpha_dict=network_alphas, peft_state_dict=state_dict, load_with_metadata=load_with_metadata, + prefix=prefix, ) _maybe_raise_error_for_ambiguity(lora_config_kwargs) @@ -460,10 +466,6 @@ def save_lora_adapter( underlying model has multiple adapters loaded. upcast_before_saving (`bool`, defaults to `False`): Whether to cast the underlying model to `torch.float32` before serialization. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with. @@ -493,9 +495,15 @@ def save_lora_adapter( if safe_serialization: def save_function(weights, filename): + # We need to be able to serialize the NoneTypes too, otherwise we run into + # 'NoneType' object cannot be converted to 'PyString' metadata = {"format": "pt"} if lora_adapter_metadata is not None: - metadata.update(lora_adapter_metadata) + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata["lora_adapter_config"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + return safetensors.torch.save_file(weights, filename, metadata=metadata) else: diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 28c3ab29773f..cd63616178bc 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -147,11 +147,16 @@ def unscale_lora_layers(model, weight: Optional[float] = None): module.set_scale(adapter_name, 1.0) -def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, load_with_metadata=False): +def get_peft_kwargs( + rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False +): if load_with_metadata: if "_metadata" not in peft_state_dict: raise ValueError("Couldn't find '_metadata' key in the `peft_state_dict`.") - return peft_state_dict["_metadata"] + metadata = peft_state_dict["_metadata"] + if prefix is not None: + metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()} + return metadata rank_pattern = {} alpha_pattern = {} diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 3682c5bfacd6..45922ef162d2 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -16,6 +16,7 @@ """ import enum +import json from .import_utils import is_torch_available from .logging import get_logger @@ -347,3 +348,19 @@ def state_dict_all_zero(state_dict, filter_str=None): state_dict = {k: v 
for k, v in state_dict.items() if any(f in k for f in filter_str)} return all(torch.all(param == 0).item() for param in state_dict.values()) + + +def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_key): + import safetensors.torch + + with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f: + if hasattr(f, "metadata"): + metadata = f.metadata() + if metadata is not None: + metadata_keys = list(metadata.keys()) + if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"): + peft_metadata = {k: v for k, v in metadata.items() if k != "format"} + state_dict["_metadata"] = json.loads(peft_metadata[metadata_key]) + else: + raise ValueError("Metadata couldn't be parsed from the safetensors file.") + return state_dict diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 8d13339e554d..0de2c5978516 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -16,6 +16,7 @@ import tempfile import unittest +import numpy as np import torch from transformers import AutoTokenizer, T5EncoderModel @@ -30,7 +31,7 @@ sys.path.append(".") -from utils import PeftLoraLoaderMixinTests # noqa: E402 +from utils import PeftLoraLoaderMixinTests, check_if_dicts_are_equal # noqa: E402 @require_peft_backend @@ -139,13 +140,39 @@ def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_save_load(self): pass - def test_save_load_with_adapter_metadata(self): + def test_adapter_metadata_is_loaded_correctly(self): # Will write the test in utils.py eventually. scheduler_cls = self.scheduler_classes[0] components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) + + pipe, _ = self.check_if_adapters_added_correctly( + pipe, text_lora_config=None, denoiser_lora_config=denoiser_lora_config + ) + + with tempfile.TemporaryDirectory() as tmpdir: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + metadata = denoiser_lora_config.to_dict() + self.pipeline_class.save_lora_weights( + save_directory=tmpdir, + transformer_lora_adapter_metadata=metadata, + **lora_state_dicts, + ) + pipe.unload_lora_weights() + state_dict = pipe.lora_state_dict(tmpdir, load_with_metadata=True) + + self.assertTrue("_metadata" in state_dict) + + parsed_metadata = state_dict["_metadata"] + parsed_metadata = {k[len("transformer.") :]: v for k, v in parsed_metadata.items()} + check_if_dicts_are_equal(parsed_metadata, metadata) + + def test_adapter_metadata_save_load_inference(self): + # Will write the test in utils.py eventually. 
+ scheduler_cls = self.scheduler_classes[0] + components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components).to(torch_device) _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] @@ -154,13 +181,22 @@ def test_save_load_with_adapter_metadata(self): pipe, _ = self.check_if_adapters_added_correctly( pipe, text_lora_config=None, denoiser_lora_config=denoiser_lora_config ) + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - with tempfile.TemporaryDirectory() as tmpdirname: + with tempfile.TemporaryDirectory() as tmpdir: modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + metadata = denoiser_lora_config.to_dict() self.pipeline_class.save_lora_weights( - save_directory=tmpdirname, - safe_serialization=False, - lora_adapter_metadata=denoiser_lora_config.to_dict(), + save_directory=tmpdir, + transformer_lora_adapter_metadata=metadata, **lora_state_dicts, ) + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdir, load_with_metadata=True) + + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue( + np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." + ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 27fef495a484..9cd26f221850 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -61,6 +61,33 @@ def state_dicts_almost_equal(sd1, sd2): return models_are_equal +def check_if_dicts_are_equal(dict1, dict2): + for key, value in dict1.items(): + if isinstance(value, set): + dict1[key] = list(value) + for key, value in dict2.items(): + if isinstance(value, set): + dict2[key] = list(value) + + for key in dict1: + if key not in dict2: + raise ValueError( + f"Key '{key}' is missing in the second dictionary. Its value in the first dictionary is {dict1[key]}." + ) + if dict1[key] != dict2[key]: + raise ValueError( + f"Difference found at key '{key}': first dictionary has {dict1[key]}, second dictionary has {dict2[key]}." + ) + + for key in dict2: + if key not in dict1: + raise ValueError( + f"Key '{key}' is missing in the first dictionary. Its value in the second dictionary is {dict2[key]}." 
+ ) + + return True + + def check_if_lora_correctly_set(model) -> bool: """ Checks if the LoRA layers are correctly set with peft From ba546bcbd84dbe1bfa33fe873ded354a8f26751e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Apr 2025 15:05:50 +0530 Subject: [PATCH 03/53] fix tests --- src/diffusers/loaders/lora_pipeline.py | 2 +- src/diffusers/loaders/peft.py | 7 ++----- src/diffusers/utils/peft_utils.py | 4 ++-- src/diffusers/utils/state_dict_utils.py | 2 +- tests/lora/test_lora_layers_wan.py | 4 ++-- 5 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 00596b1b0139..8456f04d9e84 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4889,7 +4889,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - load_with_metdata=load_with_metdata, + load_with_metadata=load_with_metdata, ) @classmethod diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 70425c96b153..6b61fe03724a 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -236,15 +236,12 @@ def load_lora_adapter( if network_alphas is not None and prefix is None: raise ValueError("`network_alphas` cannot be None when `prefix` is None.") - if load_with_metadata is not None and not use_safetensors: - raise ValueError("`load_with_metadata` cannot be specified when not using `use_safetensors`.") - if prefix is not None: - metadata = state_dict.pop("_metadata", None) + metadata = state_dict.pop("lora_metadata", None) state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if metadata is not None: - state_dict["_metadata"] = metadata + state_dict["lora_metadata"] = metadata if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index cd63616178bc..6408de79c2a7 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -151,9 +151,9 @@ def get_peft_kwargs( rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False ): if load_with_metadata: - if "_metadata" not in peft_state_dict: + if "lora_metadata" not in peft_state_dict: raise ValueError("Couldn't find '_metadata' key in the `peft_state_dict`.") - metadata = peft_state_dict["_metadata"] + metadata = peft_state_dict["lora_metadata"] if prefix is not None: metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()} return metadata diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 45922ef162d2..2723ab822df1 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -360,7 +360,7 @@ def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_ke metadata_keys = list(metadata.keys()) if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"): peft_metadata = {k: v for k, v in metadata.items() if k != "format"} - state_dict["_metadata"] = json.loads(peft_metadata[metadata_key]) + state_dict["lora_metadata"] = json.loads(peft_metadata[metadata_key]) else: raise ValueError("Metadata couldn't be parsed from the safetensors file.") return state_dict diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 0de2c5978516..4d8b06d748e1 100644 --- 
a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -162,9 +162,9 @@ def test_adapter_metadata_is_loaded_correctly(self): pipe.unload_lora_weights() state_dict = pipe.lora_state_dict(tmpdir, load_with_metadata=True) - self.assertTrue("_metadata" in state_dict) + self.assertTrue("lora_metadata" in state_dict) - parsed_metadata = state_dict["_metadata"] + parsed_metadata = state_dict["lora_metadata"] parsed_metadata = {k[len("transformer.") :]: v for k, v in parsed_metadata.items()} check_if_dicts_are_equal(parsed_metadata, metadata) From 61d37086b5a04fc3e4ef513d37f66cf8ecc92fef Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Apr 2025 17:26:38 +0530 Subject: [PATCH 04/53] key renaming --- src/diffusers/loaders/lora_base.py | 4 ++-- src/diffusers/loaders/peft.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index a845fab8d97e..a2295edb7b2b 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -244,7 +244,7 @@ def _fetch_state_dict( state_dict = safetensors.torch.load_file(model_file, device="cpu") if load_with_metadata: state_dict = _maybe_populate_state_dict_with_metadata( - state_dict, model_file, metadata_key="lora_adapter_config" + state_dict, model_file, metadata_key="lora_adapter_metadata" ) except (IOError, safetensors.SafetensorError) as e: @@ -915,7 +915,7 @@ def save_function(weights, filename): for key, value in lora_adapter_metadata.items(): if isinstance(value, set): lora_adapter_metadata[key] = list(value) - metadata["lora_adapter_config"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) return safetensors.torch.save_file(weights, filename, metadata=metadata) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 81f70ce587bb..8fd18c64df40 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -500,7 +500,7 @@ def save_function(weights, filename): for key, value in lora_adapter_metadata.items(): if isinstance(value, set): lora_adapter_metadata[key] = list(value) - metadata["lora_adapter_config"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) return safetensors.torch.save_file(weights, filename, metadata=metadata) From e98fb846e48e709ce2ce5918a1badc85de487179 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Apr 2025 18:54:54 +0530 Subject: [PATCH 05/53] fix --- src/diffusers/loaders/lora_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 35b19b13ae22..76ad07355d0f 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5195,7 +5195,7 @@ def load_lora_weights( raise ValueError("PEFT backend is required for this method.") low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - load_with_metdata = kwargs.get("load_with_metdata", False) + load_with_metadata = kwargs.get("load_with_metadata", False) if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
@@ -5222,7 +5222,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - load_with_metadata=load_with_metdata, + load_with_metadata=load_with_metadata, ) @classmethod From 42bb6bc50c14aa2bf461fdf456abebe2fdebc22f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 16:00:30 +0530 Subject: [PATCH 06/53] smol update --- src/diffusers/utils/peft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 83c43e822757..ede9fd65d57e 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -152,7 +152,7 @@ def get_peft_kwargs( ): if load_with_metadata: if "lora_metadata" not in peft_state_dict: - raise ValueError("Couldn't find '_metadata' key in the `peft_state_dict`.") + raise ValueError("Couldn't find 'lora_metadata' key in the `peft_state_dict`.") metadata = peft_state_dict["lora_metadata"] if prefix is not None: metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()} From 7ec4ef48c4eb60d9b5f5ad9960f16d5d28e79227 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 16:05:00 +0530 Subject: [PATCH 07/53] smol updates --- src/diffusers/loaders/peft.py | 4 ++-- src/diffusers/utils/peft_utils.py | 6 +++--- src/diffusers/utils/state_dict_utils.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index f0911df8d4c1..4999ef834295 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -239,11 +239,11 @@ def load_lora_adapter( raise ValueError("`network_alphas` cannot be None when `prefix` is None.") if prefix is not None: - metadata = state_dict.pop("lora_metadata", None) + metadata = state_dict.pop("lora_adapter_metadata", None) state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if metadata is not None: - state_dict["lora_metadata"] = metadata + state_dict["lora_adapter_metadata"] = metadata if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index ede9fd65d57e..f3cdc2222e29 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -151,9 +151,9 @@ def get_peft_kwargs( rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False ): if load_with_metadata: - if "lora_metadata" not in peft_state_dict: - raise ValueError("Couldn't find 'lora_metadata' key in the `peft_state_dict`.") - metadata = peft_state_dict["lora_metadata"] + if "lora_adapter_metadata" not in peft_state_dict: + raise ValueError("Couldn't find 'lora_adapter_metadata' key in the `peft_state_dict`.") + metadata = peft_state_dict["lora_adapter_metadata"] if prefix is not None: metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()} return metadata diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 30c6f48e5a04..88a2e9b1f338 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -360,7 +360,7 @@ def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_ke metadata_keys = list(metadata.keys()) if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"): peft_metadata = {k: v for k, v in metadata.items() if k != "format"} - 
state_dict["lora_metadata"] = json.loads(peft_metadata[metadata_key]) + state_dict["lora_adapter_metadata"] = json.loads(peft_metadata[metadata_key]) else: raise ValueError("Metadata couldn't be parsed from the safetensors file.") return state_dict From 7f59ca00c6d569e9a0fc72ca9ad89882fd019c31 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 16:32:48 +0530 Subject: [PATCH 08/53] load metadata. --- src/diffusers/loaders/lora_base.py | 9 +--- src/diffusers/loaders/lora_pipeline.py | 69 ++++++------------------- src/diffusers/loaders/peft.py | 4 -- src/diffusers/utils/peft_utils.py | 14 +++-- src/diffusers/utils/state_dict_utils.py | 10 ++-- tests/lora/test_lora_layers_wan.py | 5 +- 6 files changed, 38 insertions(+), 73 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index a2295edb7b2b..d4d3e659fc9c 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -64,6 +64,7 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" +LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): @@ -208,7 +209,6 @@ def _fetch_state_dict( subfolder, user_agent, allow_pickle, - load_with_metadata=False, ): model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): @@ -226,8 +226,6 @@ def _fetch_state_dict( file_extension=".safetensors", local_files_only=local_files_only, ) - if load_with_metadata and not weight_name.endswith(".safetensors"): - raise ValueError("`load_with_metadata` cannot be set to True when not using safetensors.") model_file = _get_model_file( pretrained_model_name_or_path_or_dict, @@ -242,10 +240,7 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") - if load_with_metadata: - state_dict = _maybe_populate_state_dict_with_metadata( - state_dict, model_file, metadata_key="lora_adapter_metadata" - ) + state_dict = _maybe_populate_state_dict_with_metadata(state_dict, model_file) except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index fa35f6415417..eee7fba4e1d1 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4727,7 +4727,6 @@ def lora_state_dict( - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - load_with_metadata: TODO cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. 
@@ -4762,7 +4761,6 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) - load_with_metadata = kwargs.pop("load_with_metadata", False) allow_pickle = False if use_safetensors is None: @@ -4787,7 +4785,6 @@ def lora_state_dict( subfolder=subfolder, user_agent=user_agent, allow_pickle=allow_pickle, - load_with_metadata=load_with_metadata, ) if any(k.startswith("diffusion_model.") for k in state_dict): state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) @@ -4861,7 +4858,6 @@ def load_lora_weights( raise ValueError("PEFT backend is required for this method.") low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - load_with_metadata = kwargs.get("load_with_metadata", False) if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." @@ -4888,7 +4884,6 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, - load_with_metadata=load_with_metadata, hotswap=hotswap, ) @@ -4902,54 +4897,25 @@ def load_lora_into_transformer( _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, - load_with_metadata: bool = False, ): """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed - directly into the unet or prefixed with an additional `unet` which can be used to distinguish - between text encoder lora layers. - transformer (`WanTransformer3DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the - random weights. - <<<<<<< HEAD - hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded - adapter in-place. This means that, instead of loading an additional adapter, this will take the - existing adapter weights and replace them with the weights of the new adapter. This can be - faster and more memory efficient. However, the main advantage of hotswapping is that when the - model is compiled with torch.compile, loading the new adapter does not require recompilation of - the model. When using hotswapping, the passed `adapter_name` should be the name of an already - loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), - you need to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap - load_with_metadata: TODO - ======= - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - >>>>>>> main + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`WanTransformer3DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4965,7 +4931,6 @@ def load_lora_into_transformer( _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, - load_with_metadata=load_with_metadata, ) @classmethod diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 4999ef834295..23ec15d0315a 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -120,7 +120,6 @@ def load_lora_adapter( pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, - load_with_metadata: bool = False, **kwargs, ): r""" @@ -190,7 +189,6 @@ def load_lora_adapter( limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap - load_with_metadata: TODO """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer @@ -233,7 +231,6 @@ def load_lora_adapter( subfolder=subfolder, user_agent=user_agent, allow_pickle=allow_pickle, - load_with_metadata=load_with_metadata, ) if network_alphas is not None and prefix is None: raise ValueError("`network_alphas` cannot be None when `prefix` is None.") @@ -280,7 +277,6 @@ def load_lora_adapter( rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict, - load_with_metadata=load_with_metadata, prefix=prefix, ) _maybe_raise_error_for_ambiguity(lora_config_kwargs) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index f3cdc2222e29..9f73006effee 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -148,12 +148,16 @@ def unscale_lora_layers(model, weight: Optional[float] = None): def get_peft_kwargs( - rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False + rank_dict, + network_alpha_dict, + peft_state_dict, + is_unet=True, + prefix=None, ): - if load_with_metadata: - if "lora_adapter_metadata" not in peft_state_dict: - raise ValueError("Couldn't find 'lora_adapter_metadata' key in the `peft_state_dict`.") - metadata = peft_state_dict["lora_adapter_metadata"] + from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY + + if LORA_ADAPTER_METADATA_KEY in peft_state_dict: + metadata = 
peft_state_dict[LORA_ADAPTER_METADATA_KEY] if prefix is not None: metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()} return metadata diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 88a2e9b1f338..0dbbce5713b7 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -350,9 +350,15 @@ def state_dict_all_zero(state_dict, filter_str=None): return all(torch.all(param == 0).item() for param in state_dict.values()) -def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_key): +def _maybe_populate_state_dict_with_metadata(state_dict, model_file): + if not model_file.endswith(".safetensors"): + return state_dict + import safetensors.torch + from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY + + metadata_key = LORA_ADAPTER_METADATA_KEY with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f: if hasattr(f, "metadata"): metadata = f.metadata() @@ -361,6 +367,4 @@ def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_ke if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"): peft_metadata = {k: v for k, v in metadata.items() if k != "format"} state_dict["lora_adapter_metadata"] = json.loads(peft_metadata[metadata_key]) - else: - raise ValueError("Metadata couldn't be parsed from the safetensors file.") return state_dict diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 4d8b06d748e1..b6c3a07f875f 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -26,6 +26,7 @@ WanPipeline, WanTransformer3DModel, ) +from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device @@ -162,9 +163,9 @@ def test_adapter_metadata_is_loaded_correctly(self): pipe.unload_lora_weights() state_dict = pipe.lora_state_dict(tmpdir, load_with_metadata=True) - self.assertTrue("lora_metadata" in state_dict) + self.assertTrue(LORA_ADAPTER_METADATA_KEY in state_dict) - parsed_metadata = state_dict["lora_metadata"] + parsed_metadata = state_dict[LORA_ADAPTER_METADATA_KEY] parsed_metadata = {k[len("transformer.") :]: v for k, v in parsed_metadata.items()} check_if_dicts_are_equal(parsed_metadata, metadata) From ded2fd64327ab223ccd2b60f78c067160c5f1169 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 16:51:17 +0530 Subject: [PATCH 09/53] automatically save metadata in save_lora_adapter. 
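
With this change the adapter's `peft_config` entry is what gets serialized into the safetensors
header under the `lora_adapter_metadata` key, so the config can be inspected without loading any
tensors (the loading side does the same via `_maybe_populate_state_dict_with_metadata`). A minimal
sketch of reading it back, assuming the default weight name used elsewhere in diffusers:

    import json

    from safetensors import safe_open

    # "pytorch_lora_weights.safetensors" is an assumed file name (the LORA_WEIGHT_NAME_SAFE default).
    with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
        header = f.metadata() or {}

    # The adapter config is stored as a JSON string next to the usual {"format": "pt"} entry.
    adapter_config = json.loads(header["lora_adapter_metadata"])
    print(adapter_config.get("r"), adapter_config.get("lora_alpha"))
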
--- src/diffusers/loaders/peft.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 23ec15d0315a..af588f75c200 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -193,6 +193,8 @@ def load_lora_adapter( from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer + from .lora_base import LORA_ADAPTER_METADATA_KEY + cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -236,11 +238,11 @@ def load_lora_adapter( raise ValueError("`network_alphas` cannot be None when `prefix` is None.") if prefix is not None: - metadata = state_dict.pop("lora_adapter_metadata", None) + metadata = state_dict.pop(LORA_ADAPTER_METADATA_KEY, None) state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if metadata is not None: - state_dict["lora_adapter_metadata"] = metadata + state_dict[LORA_ADAPTER_METADATA_KEY] = metadata if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: @@ -464,16 +466,10 @@ def save_lora_adapter( safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with. - lora_adapter_metadata: TODO """ from peft.utils import get_peft_model_state_dict - from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE - - if lora_adapter_metadata is not None and not safe_serialization: - raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.") - if not isinstance(lora_adapter_metadata, dict): - raise ValueError("`lora_adapter_metadata` must be of type `dict`.") + from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE if adapter_name is None: adapter_name = get_adapter_name(self) @@ -481,6 +477,8 @@ def save_lora_adapter( if adapter_name not in getattr(self, "peft_config", {}): raise ValueError(f"Adapter name {adapter_name} not found in the model.") + lora_adapter_metadata = self.peft_config[adapter_name] + lora_layers_to_save = get_peft_model_state_dict( self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name ) @@ -497,7 +495,7 @@ def save_function(weights, filename): for key, value in lora_adapter_metadata.items(): if isinstance(value, set): lora_adapter_metadata[key] = list(value) - metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) return safetensors.torch.save_file(weights, filename, metadata=metadata) @@ -512,7 +510,6 @@ def save_function(weights, filename): else: weight_name = LORA_WEIGHT_NAME - # TODO: we could consider saving the `peft_config` as well. save_path = Path(save_directory, weight_name).as_posix() save_function(lora_layers_to_save, save_path) logger.info(f"Model weights saved in {save_path}") From d5b3037dbcf69d19d1faf072a62072e7b190f99e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 16:57:00 +0530 Subject: [PATCH 10/53] propagate changes. 
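
This threads the new `transformer_lora_adapter_metadata` argument through the remaining
pipeline-level `save_lora_weights` implementations. A hedged usage sketch, modeled on the Wan
tests added earlier in this series (the pipeline object, its LoRA state dict, and the output
directory are assumptions, not part of this patch):

    from peft import LoraConfig

    # Assumed: `pipe` already has a LoRA adapter injected and `transformer_lora_layers`
    # holds the matching LoRA state dict.
    denoiser_lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])

    pipe.save_lora_weights(
        save_directory="wan-lora",  # assumed output directory
        transformer_lora_layers=transformer_lora_layers,
        transformer_lora_adapter_metadata=denoiser_lora_config.to_dict(),
    )

The metadata dict is packed with the `transformer.` prefix via `pack_weights` and written into the
safetensors header the same way as the model-level path in the previous patch.
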
--- src/diffusers/loaders/lora_pipeline.py | 91 +++++++++++++++++++++----- 1 file changed, 74 insertions(+), 17 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index eee7fba4e1d1..8392a3f3936c 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1667,9 +1667,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -1686,8 +1687,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") @@ -1695,6 +1698,9 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -1703,6 +1709,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -2985,9 +2992,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -3004,8 +3012,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") @@ -3013,6 +3023,9 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -3021,6 +3034,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -3302,9 +3316,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. 
Arguments: save_directory (`str` or `os.PathLike`): @@ -3321,8 +3336,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") @@ -3330,6 +3347,9 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -3338,6 +3358,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -3621,9 +3642,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -3640,8 +3662,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") @@ -3649,6 +3673,9 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -3657,6 +3684,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -3940,9 +3968,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -3959,8 +3988,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
+ transformer_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") @@ -3968,6 +3999,9 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -3976,6 +4010,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4262,9 +4297,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -4281,8 +4317,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") @@ -4290,6 +4328,9 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -4298,6 +4339,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4585,9 +4627,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -4604,8 +4647,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
+ transformer_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") @@ -4613,6 +4658,9 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -4621,6 +4669,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -4890,13 +4939,7 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( - cls, - state_dict, - transformer, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4946,7 +4989,7 @@ def save_lora_weights( transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -5269,9 +5312,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -5288,8 +5332,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") @@ -5297,6 +5343,9 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -5305,6 +5354,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -5588,9 +5638,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -5607,8 +5658,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. 
safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") @@ -5616,6 +5669,9 @@ def save_lora_weights( if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -5624,6 +5680,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora From bee9e003d825bbe3e71a4acc0a2af9713e14758a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 17:21:14 +0530 Subject: [PATCH 11/53] changes --- src/diffusers/loaders/lora_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index d4d3e659fc9c..2a8f29822c2d 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -894,7 +894,7 @@ def write_lora_layers( logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return - if lora_adapter_metadata is not None and not safe_serialization: + if lora_adapter_metadata and not safe_serialization: raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.") if not isinstance(lora_adapter_metadata, dict): raise ValueError("`lora_adapter_metadata` must be of type `dict`.") From a9f5088c5a138840cf960bd87f40fcb229b3e5a5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 17:47:53 +0530 Subject: [PATCH 12/53] add test to models too. --- src/diffusers/loaders/peft.py | 3 +- src/diffusers/utils/testing_utils.py | 27 +++++++++++++++++ tests/lora/test_lora_layers_wan.py | 14 ++++++--- tests/lora/utils.py | 27 ----------------- tests/models/test_modeling_common.py | 43 +++++++++++++++++++++++++++- 5 files changed, 80 insertions(+), 34 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index af588f75c200..f2194bb2dab2 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -451,7 +451,6 @@ def save_lora_adapter( upcast_before_saving: bool = False, safe_serialization: bool = True, weight_name: Optional[str] = None, - lora_adapter_metadata: Optional[dict] = None, ): """ Save the LoRA parameters corresponding to the underlying model. 
@@ -477,7 +476,7 @@ def save_lora_adapter( if adapter_name not in getattr(self, "peft_config", {}): raise ValueError(f"Adapter name {adapter_name} not found in the model.") - lora_adapter_metadata = self.peft_config[adapter_name] + lora_adapter_metadata = self.peft_config[adapter_name].to_dict() lora_layers_to_save = get_peft_model_state_dict( self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7a524e76f16e..9d44969394d9 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -132,6 +132,33 @@ def numpy_cosine_similarity_distance(a, b): return distance +def check_if_dicts_are_equal(dict1, dict2): + for key, value in dict1.items(): + if isinstance(value, set): + dict1[key] = list(value) + for key, value in dict2.items(): + if isinstance(value, set): + dict2[key] = list(value) + + for key in dict1: + if key not in dict2: + raise ValueError( + f"Key '{key}' is missing in the second dictionary. Its value in the first dictionary is {dict1[key]}." + ) + if dict1[key] != dict2[key]: + raise ValueError( + f"Difference found at key '{key}': first dictionary has {dict1[key]}, second dictionary has {dict2[key]}." + ) + + for key in dict2: + if key not in dict1: + raise ValueError( + f"Key '{key}' is missing in the first dictionary. Its value in the second dictionary is {dict2[key]}." + ) + + return True + + def print_tensor_test( tensor, limit_to_slices=None, diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index b6c3a07f875f..bdac22d0e477 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -27,12 +27,18 @@ WanTransformer3DModel, ) from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device +from diffusers.utils.testing_utils import ( + check_if_dicts_are_equal, + floats_tensor, + require_peft_backend, + skip_mps, + torch_device, +) sys.path.append(".") -from utils import PeftLoraLoaderMixinTests, check_if_dicts_are_equal # noqa: E402 +from utils import PeftLoraLoaderMixinTests # noqa: E402 @require_peft_backend @@ -161,7 +167,7 @@ def test_adapter_metadata_is_loaded_correctly(self): **lora_state_dicts, ) pipe.unload_lora_weights() - state_dict = pipe.lora_state_dict(tmpdir, load_with_metadata=True) + state_dict = pipe.lora_state_dict(tmpdir) self.assertTrue(LORA_ADAPTER_METADATA_KEY in state_dict) @@ -194,7 +200,7 @@ def test_adapter_metadata_save_load_inference(self): **lora_state_dicts, ) pipe.unload_lora_weights() - pipe.load_lora_weights(tmpdir, load_with_metadata=True) + pipe.load_lora_weights(tmpdir) output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] diff --git a/tests/lora/utils.py b/tests/lora/utils.py index a169fe40b4b3..87a8fddfa583 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -61,33 +61,6 @@ def state_dicts_almost_equal(sd1, sd2): return models_are_equal -def check_if_dicts_are_equal(dict1, dict2): - for key, value in dict1.items(): - if isinstance(value, set): - dict1[key] = list(value) - for key, value in dict2.items(): - if isinstance(value, set): - dict2[key] = list(value) - - for key in dict1: - if key not in dict2: - raise ValueError( - f"Key '{key}' is missing in the second dictionary. Its value in the first dictionary is {dict1[key]}." 
- ) - if dict1[key] != dict2[key]: - raise ValueError( - f"Difference found at key '{key}': first dictionary has {dict1[key]}, second dictionary has {dict2[key]}." - ) - - for key in dict2: - if key not in dict1: - raise ValueError( - f"Key '{key}' is missing in the first dictionary. Its value in the second dictionary is {dict2[key]}." - ) - - return True - - def check_if_lora_correctly_set(model) -> bool: """ Checks if the LoRA layers are correctly set with peft diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 57431d8b161b..7e854fbebe16 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -30,6 +30,7 @@ import numpy as np import requests_mock +import safetensors.torch import torch import torch.nn as nn from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size @@ -62,6 +63,7 @@ backend_max_memory_allocated, backend_reset_peak_memory_stats, backend_synchronize, + check_if_dicts_are_equal, floats_tensor, get_python_version, is_torch_compile, @@ -1062,7 +1064,6 @@ def test_deprecated_kwargs(self): @torch.no_grad() @unittest.skipIf(not is_peft_available(), "Only with PEFT") def test_save_load_lora_adapter(self, use_dora=False): - import safetensors from peft import LoraConfig from peft.utils import get_peft_model_state_dict @@ -1146,6 +1147,46 @@ def test_wrong_adapter_name_raises_error(self): self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception)) + @torch.no_grad() + @unittest.skipIf(not is_peft_available(), "Only with PEFT") + def test_adapter_metadata_is_loaded_correctly(self): + from peft import LoraConfig + + from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + from diffusers.loaders.peft import PeftAdapterMixin + + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + if not issubclass(model.__class__, PeftAdapterMixin): + return + + denoiser_lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + model.add_adapter(denoiser_lora_config) + metadata = model.peft_config["default"].to_dict() + self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_lora_adapter(tmpdir) + model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(model_file)) + + with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f: + if hasattr(f, "metadata"): + parsed_metadata = f.metadata() + parsed_metadata = {k: v for k, v in parsed_metadata.items() if k != "format"} + self.assertTrue(LORA_ADAPTER_METADATA_KEY in parsed_metadata) + parsed_metadata = {k: v for k, v in parsed_metadata.items() if k != "format"} + + parsed_metadata = json.loads(parsed_metadata[LORA_ADAPTER_METADATA_KEY]) + check_if_dicts_are_equal(parsed_metadata, metadata) + @require_torch_accelerator def test_cpu_offload(self): config, inputs_dict = self.prepare_init_args_and_inputs_for_common() From 771630301916342b6fffcbe1a4648b4641d98c43 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 17:57:32 +0530 Subject: [PATCH 13/53] tigher tests. 
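Parameterize the adapter (de)serialization tests over (rank, lora_alpha, use_dora) and, instead of parsing the safetensors header by hand, reload the adapter and compare the resulting peft_config against the original one.

For context, a minimal sketch of how the metadata written by the previous patches can be inspected directly; the file name mirrors what the tests use and the path is illustrative, not part of this patch:

    import json

    import safetensors.torch

    from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

    model_file = "pytorch_lora_weights.safetensors"  # e.g. produced by save_lora_adapter()

    with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
        header = f.metadata() or {}

    if LORA_ADAPTER_METADATA_KEY in header:
        # The adapter config is stored as a JSON string next to the usual {"format": "pt"} entry.
        lora_config_dict = json.loads(header[LORA_ADAPTER_METADATA_KEY])
        print(lora_config_dict.get("r"), lora_config_dict.get("lora_alpha"))
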
--- tests/models/test_modeling_common.py | 31 +++++++++++++--------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 7e854fbebe16..3bed2a87ce18 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1060,10 +1060,10 @@ def test_deprecated_kwargs(self): " from `_deprecated_kwargs = []`" ) - @parameterized.expand([True, False]) + @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)]) @torch.no_grad() @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_save_load_lora_adapter(self, use_dora=False): + def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False): from peft import LoraConfig from peft.utils import get_peft_model_state_dict @@ -1079,8 +1079,8 @@ def test_save_load_lora_adapter(self, use_dora=False): output_no_lora = model(**inputs_dict, return_dict=False)[0] denoiser_lora_config = LoraConfig( - r=4, - lora_alpha=4, + r=rank, + lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False, use_dora=use_dora, @@ -1147,12 +1147,12 @@ def test_wrong_adapter_name_raises_error(self): self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception)) + @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)]) @torch.no_grad() @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_adapter_metadata_is_loaded_correctly(self): + def test_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora): from peft import LoraConfig - from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY from diffusers.loaders.peft import PeftAdapterMixin init_dict, _ = self.prepare_init_args_and_inputs_for_common() @@ -1162,11 +1162,11 @@ def test_adapter_metadata_is_loaded_correctly(self): return denoiser_lora_config = LoraConfig( - r=4, - lora_alpha=4, + r=rank, + lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False, - use_dora=False, + use_dora=use_dora, ) model.add_adapter(denoiser_lora_config) metadata = model.peft_config["default"].to_dict() @@ -1177,15 +1177,12 @@ def test_adapter_metadata_is_loaded_correctly(self): model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") self.assertTrue(os.path.isfile(model_file)) - with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f: - if hasattr(f, "metadata"): - parsed_metadata = f.metadata() - parsed_metadata = {k: v for k, v in parsed_metadata.items() if k != "format"} - self.assertTrue(LORA_ADAPTER_METADATA_KEY in parsed_metadata) - parsed_metadata = {k: v for k, v in parsed_metadata.items() if k != "format"} + model.unload_lora() + self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly") - parsed_metadata = json.loads(parsed_metadata[LORA_ADAPTER_METADATA_KEY]) - check_if_dicts_are_equal(parsed_metadata, metadata) + model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) + parsed_metadata = model.peft_config["default_0"].to_dict() + check_if_dicts_are_equal(metadata, parsed_metadata) @require_torch_accelerator def test_cpu_offload(self): From 0ac1a39f4c69d6019a49fb235d741eceab64ef23 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 18:12:29 +0530 Subject: [PATCH 14/53] updates --- src/diffusers/loaders/lora_base.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 50 +++++++++++--------------- src/diffusers/utils/peft_utils.py | 
7 ++-- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 2a8f29822c2d..c45234191687 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -906,7 +906,7 @@ def save_function(weights, filename): # We need to be able to serialize the NoneTypes too, otherwise we run into # 'NoneType' object cannot be converted to 'PyString' metadata = {"format": "pt"} - if lora_adapter_metadata is not None: + if lora_adapter_metadata: for key, value in lora_adapter_metadata.items(): if isinstance(value, set): lora_adapter_metadata[key] = list(value) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 8392a3f3936c..0bdd060e0a3d 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1695,10 +1695,9 @@ def save_lora_weights( if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if transformer_lora_adapter_metadata: + if transformer_lora_adapter_metadata is not None: lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) # Save the model @@ -3020,10 +3019,9 @@ def save_lora_weights( if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if transformer_lora_adapter_metadata: + if transformer_lora_adapter_metadata is not None: lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) # Save the model @@ -3344,10 +3342,9 @@ def save_lora_weights( if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if transformer_lora_adapter_metadata: + if transformer_lora_adapter_metadata is not None: lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) # Save the model @@ -3670,10 +3667,9 @@ def save_lora_weights( if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if transformer_lora_adapter_metadata: + if transformer_lora_adapter_metadata is not None: lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) # Save the model @@ -3996,10 +3992,9 @@ def save_lora_weights( if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if transformer_lora_adapter_metadata: + if transformer_lora_adapter_metadata is not None: 
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) # Save the model @@ -4325,10 +4320,9 @@ def save_lora_weights( if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if transformer_lora_adapter_metadata: + if transformer_lora_adapter_metadata is not None: lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) # Save the model @@ -4655,10 +4649,9 @@ def save_lora_weights( if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if transformer_lora_adapter_metadata: + if transformer_lora_adapter_metadata is not None: lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) # Save the model @@ -5014,10 +5007,9 @@ def save_lora_weights( if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if transformer_lora_adapter_metadata: + if transformer_lora_adapter_metadata is not None: lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) # Save the model @@ -5340,10 +5332,9 @@ def save_lora_weights( if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if transformer_lora_adapter_metadata: + if transformer_lora_adapter_metadata is not None: lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) # Save the model @@ -5666,10 +5657,9 @@ def save_lora_weights( if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if transformer_lora_adapter_metadata: + if transformer_lora_adapter_metadata is not None: lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) # Save the model diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 9f73006effee..d290231589a4 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -158,9 +158,10 @@ def get_peft_kwargs( if LORA_ADAPTER_METADATA_KEY in peft_state_dict: metadata = peft_state_dict[LORA_ADAPTER_METADATA_KEY] - if prefix is not None: - metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()} - return metadata + if metadata: + if prefix is not None: + metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()} + return metadata rank_pattern = {} alpha_pattern = {} From 
4b51bbf89c9c48353eabe3b925884af956512292 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 18:30:43 +0530 Subject: [PATCH 15/53] fixes --- src/diffusers/loaders/lora_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index c45234191687..7b1d16cd9ec5 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -896,7 +896,7 @@ def write_lora_layers( if lora_adapter_metadata and not safe_serialization: raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.") - if not isinstance(lora_adapter_metadata, dict): + if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict): raise ValueError("`lora_adapter_metadata` must be of type `dict`.") if save_function is None: From e2ca95aed6b6933093bb64e6e7fc0cc55286d2d7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 20:24:48 +0530 Subject: [PATCH 16/53] rename tests. --- tests/models/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 3bed2a87ce18..49342f5affa5 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1150,7 +1150,7 @@ def test_wrong_adapter_name_raises_error(self): @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)]) @torch.no_grad() @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora): + def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora): from peft import LoraConfig from diffusers.loaders.peft import PeftAdapterMixin From e0449c2a88a7355adb8a26f6680be3b29e937f9c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 3 May 2025 12:54:47 +0530 Subject: [PATCH 17/53] sorted. 
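Normalize set-valued config entries with sorted() instead of list() in check_if_dicts_are_equal, so the comparison does not depend on set iteration order (target_modules, for example, is stored as a set). A small illustration of the failure mode this avoids; the values below are made up:

    # Two equal sets may iterate in different orders (string hashing is randomized
    # across processes), so comparing list(...) snapshots can fail spuriously;
    # sorting the normalized values makes the comparison deterministic.
    cfg_a = {"r": 4, "target_modules": {"to_q", "to_k", "to_v", "to_out.0"}}
    cfg_b = {"r": 4, "target_modules": {"to_out.0", "to_v", "to_k", "to_q"}}

    def normalize(cfg):
        return {k: sorted(v) if isinstance(v, set) else v for k, v in cfg.items()}

    assert normalize(cfg_a) == normalize(cfg_b)
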
--- src/diffusers/utils/testing_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 9d44969394d9..e52c7d29cae0 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -135,10 +135,10 @@ def numpy_cosine_similarity_distance(a, b): def check_if_dicts_are_equal(dict1, dict2): for key, value in dict1.items(): if isinstance(value, set): - dict1[key] = list(value) + dict1[key] = sorted(value) for key, value in dict2.items(): if isinstance(value, set): - dict2[key] = list(value) + dict2[key] = sorted(value) for key in dict1: if key not in dict2: From 918aef1a10f9940ffda73a3fff190753fc64f8bc Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 3 May 2025 23:28:29 +0530 Subject: [PATCH 18/53] Update src/diffusers/loaders/lora_base.py Co-authored-by: Benjamin Bossan --- src/diffusers/loaders/lora_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 7b1d16cd9ec5..806a6dfacabf 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -888,7 +888,7 @@ def write_lora_layers( weight_name: str, save_function: Callable, safe_serialization: bool, - lora_adapter_metadata: dict = None, + lora_adapter_metadata: Optional[dict] = None, ): if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") From 4bd325c413bb5e4689272cb3320a4f026224acf4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 3 May 2025 23:59:18 +0530 Subject: [PATCH 19/53] review suggestions. --- src/diffusers/loaders/lora_base.py | 3 +-- src/diffusers/loaders/peft.py | 9 ++++----- src/diffusers/utils/testing_utils.py | 14 +++++--------- tests/lora/test_lora_layers_wan.py | 9 ++++----- 4 files changed, 14 insertions(+), 21 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 806a6dfacabf..b88066fd88c4 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -903,8 +903,7 @@ def write_lora_layers( if safe_serialization: def save_function(weights, filename): - # We need to be able to serialize the NoneTypes too, otherwise we run into - # 'NoneType' object cannot be converted to 'PyString' + # Inject framework format. 
metadata = {"format": "pt"} if lora_adapter_metadata: for key, value in lora_adapter_metadata.items(): diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index f2194bb2dab2..d943f402a02d 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -238,11 +238,11 @@ def load_lora_adapter( raise ValueError("`network_alphas` cannot be None when `prefix` is None.") if prefix is not None: - metadata = state_dict.pop(LORA_ADAPTER_METADATA_KEY, None) state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - if metadata is not None: - state_dict[LORA_ADAPTER_METADATA_KEY] = metadata + metadata = state_dict.pop(LORA_ADAPTER_METADATA_KEY, None) + if metadata is not None: + state_dict[LORA_ADAPTER_METADATA_KEY] = metadata if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: @@ -487,8 +487,7 @@ def save_lora_adapter( if safe_serialization: def save_function(weights, filename): - # We need to be able to serialize the NoneTypes too, otherwise we run into - # 'NoneType' object cannot be converted to 'PyString' + # Inject framework format. metadata = {"format": "pt"} if lora_adapter_metadata is not None: for key, value in lora_adapter_metadata.items(): diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index e52c7d29cae0..87adfede0d2b 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -133,6 +133,8 @@ def numpy_cosine_similarity_distance(a, b): def check_if_dicts_are_equal(dict1, dict2): + dict1, dict2 = dict1.copy(), dict2.copy() + for key, value in dict1.items(): if isinstance(value, set): dict1[key] = sorted(value) @@ -142,19 +144,13 @@ def check_if_dicts_are_equal(dict1, dict2): for key in dict1: if key not in dict2: - raise ValueError( - f"Key '{key}' is missing in the second dictionary. Its value in the first dictionary is {dict1[key]}." - ) + return False if dict1[key] != dict2[key]: - raise ValueError( - f"Difference found at key '{key}': first dictionary has {dict1[key]}, second dictionary has {dict2[key]}." - ) + return False for key in dict2: if key not in dict1: - raise ValueError( - f"Key '{key}' is missing in the first dictionary. Its value in the second dictionary is {dict2[key]}." - ) + return False return True diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index bdac22d0e477..28f5cd5acfb3 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -31,7 +31,6 @@ check_if_dicts_are_equal, floats_tensor, require_peft_backend, - skip_mps, torch_device, ) @@ -42,7 +41,7 @@ @require_peft_backend -@skip_mps +# @skip_mps class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = WanPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler @@ -147,8 +146,8 @@ def test_simple_inference_with_text_lora_fused(self): def test_simple_inference_with_text_lora_save_load(self): pass - def test_adapter_metadata_is_loaded_correctly(self): - # Will write the test in utils.py eventually. + def test_lora_adapter_metadata_is_loaded_correctly(self): + # TODO: Will write the test in utils.py eventually. 
scheduler_cls = self.scheduler_classes[0] components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) @@ -175,7 +174,7 @@ def test_adapter_metadata_is_loaded_correctly(self): parsed_metadata = {k[len("transformer.") :]: v for k, v in parsed_metadata.items()} check_if_dicts_are_equal(parsed_metadata, metadata) - def test_adapter_metadata_save_load_inference(self): + def test_lora_adapter_metadata_save_load_inference(self): # Will write the test in utils.py eventually. scheduler_cls = self.scheduler_classes[0] components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls) From e8bec868edb7475b9b44a4d26ff54f3f102232ce Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 5 May 2025 17:04:51 +0530 Subject: [PATCH 20/53] removeprefix. --- src/diffusers/utils/peft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index d290231589a4..97258c29edd8 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -160,7 +160,7 @@ def get_peft_kwargs( metadata = peft_state_dict[LORA_ADAPTER_METADATA_KEY] if metadata: if prefix is not None: - metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()} + metadata = {k.removeprefix(prefix + "."): v for k, v in metadata.items()} return metadata rank_pattern = {} From 7bb6c9f339ee7c52b6ed86aa46b15e8a3b177f37 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 May 2025 14:29:37 +0530 Subject: [PATCH 21/53] propagate changes. --- src/diffusers/loaders/lora_base.py | 5 +- src/diffusers/loaders/lora_pipeline.py | 23 ++++++ src/diffusers/loaders/peft.py | 11 ++- src/diffusers/utils/peft_utils.py | 10 ++- tests/lora/test_lora_layers_wan.py | 71 +---------------- tests/lora/utils.py | 104 ++++++++++++++++++++++++- 6 files changed, 140 insertions(+), 84 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index b88066fd88c4..da83481f371a 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -353,8 +353,11 @@ def _load_lora_into_text_encoder( raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.") # Load the layers corresponding to text encoder and make necessary adjustments. 
+ if LORA_ADAPTER_METADATA_KEY in state_dict: + metadata = state_dict[LORA_ADAPTER_METADATA_KEY] if prefix is not None: state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + state_dict[LORA_ADAPTER_METADATA_KEY] = metadata if len(state_dict) > 0: logger.info(f"Loading {prefix}.") @@ -382,7 +385,7 @@ def _load_lora_into_text_encoder( alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False, prefix=prefix) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 0bdd060e0a3d..eb56f4365253 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -644,6 +644,9 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") + from .lora_base import LORA_ADAPTER_METADATA_KEY + + print(f"{LORA_ADAPTER_METADATA_KEY in state_dict=} before UNet") self.load_lora_into_unet( state_dict, network_alphas=network_alphas, @@ -653,6 +656,7 @@ def load_lora_weights( low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) + print(f"{LORA_ADAPTER_METADATA_KEY in state_dict=} before text encoder.") self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, @@ -664,6 +668,7 @@ def load_lora_weights( low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) + print(f"{LORA_ADAPTER_METADATA_KEY in state_dict=} before text encoder 2.") self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, @@ -732,6 +737,7 @@ def lora_state_dict( """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -914,6 +920,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + unet_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, + text_encoder_2_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -939,8 +948,12 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
+ unet_lora_adapter_metadata: TODO + text_encoder_lora_adapter_metadata: TODO + text_encoder_2_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -956,6 +969,15 @@ def save_lora_weights( if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + if unet_lora_adapter_metadata is not None: + lora_adapter_metadata.update(cls.pack_weights(unet_lora_adapter_metadata, cls.unet_name)) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name)) + + if text_encoder_2_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2")) + cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, @@ -963,6 +985,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index d943f402a02d..c8e4bfd63c0a 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -193,7 +193,7 @@ def load_lora_adapter( from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer - from .lora_base import LORA_ADAPTER_METADATA_KEY + from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) @@ -234,15 +234,14 @@ def load_lora_adapter( user_agent=user_agent, allow_pickle=allow_pickle, ) + if LORA_ADAPTER_METADATA_KEY in state_dict: + metadata = state_dict[LORA_ADAPTER_METADATA_KEY] if network_alphas is not None and prefix is None: raise ValueError("`network_alphas` cannot be None when `prefix` is None.") if prefix is not None: - state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - - metadata = state_dict.pop(LORA_ADAPTER_METADATA_KEY, None) - if metadata is not None: - state_dict[LORA_ADAPTER_METADATA_KEY] = metadata + state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + state_dict[LORA_ADAPTER_METADATA_KEY] = metadata if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 97258c29edd8..9a95d0f44db0 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -158,10 +158,12 @@ def get_peft_kwargs( if LORA_ADAPTER_METADATA_KEY in peft_state_dict: metadata = peft_state_dict[LORA_ADAPTER_METADATA_KEY] - if metadata: - if prefix is not None: - metadata = {k.removeprefix(prefix + "."): v for k, v in metadata.items()} - return metadata + else: + metadata = None + if metadata: + if prefix is not None: + metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} + return metadata rank_pattern = {} alpha_pattern = {} diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 28f5cd5acfb3..7866556e900f 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -13,10 +13,8 @@ # limitations under the License. 
import sys -import tempfile import unittest -import numpy as np import torch from transformers import AutoTokenizer, T5EncoderModel @@ -26,13 +24,7 @@ WanPipeline, WanTransformer3DModel, ) -from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY -from diffusers.utils.testing_utils import ( - check_if_dicts_are_equal, - floats_tensor, - require_peft_backend, - torch_device, -) +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend sys.path.append(".") @@ -145,64 +137,3 @@ def test_simple_inference_with_text_lora_fused(self): @unittest.skip("Text encoder LoRA is not supported in Wan.") def test_simple_inference_with_text_lora_save_load(self): pass - - def test_lora_adapter_metadata_is_loaded_correctly(self): - # TODO: Will write the test in utils.py eventually. - scheduler_cls = self.scheduler_classes[0] - components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - - pipe, _ = self.check_if_adapters_added_correctly( - pipe, text_lora_config=None, denoiser_lora_config=denoiser_lora_config - ) - - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - metadata = denoiser_lora_config.to_dict() - self.pipeline_class.save_lora_weights( - save_directory=tmpdir, - transformer_lora_adapter_metadata=metadata, - **lora_state_dicts, - ) - pipe.unload_lora_weights() - state_dict = pipe.lora_state_dict(tmpdir) - - self.assertTrue(LORA_ADAPTER_METADATA_KEY in state_dict) - - parsed_metadata = state_dict[LORA_ADAPTER_METADATA_KEY] - parsed_metadata = {k[len("transformer.") :]: v for k, v in parsed_metadata.items()} - check_if_dicts_are_equal(parsed_metadata, metadata) - - def test_lora_adapter_metadata_save_load_inference(self): - # Will write the test in utils.py eventually. - scheduler_cls = self.scheduler_classes[0] - components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components).to(torch_device) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - pipe, _ = self.check_if_adapters_added_correctly( - pipe, text_lora_config=None, denoiser_lora_config=denoiser_lora_config - ) - output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdir: - modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) - lora_state_dicts = self._get_lora_state_dicts(modules_to_save) - metadata = denoiser_lora_config.to_dict() - self.pipeline_class.save_lora_weights( - save_directory=tmpdir, - transformer_lora_adapter_metadata=metadata, - **lora_state_dicts, - ) - pipe.unload_lora_weights() - pipe.load_lora_weights(tmpdir) - - output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." 
- ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 87a8fddfa583..3fe83f102f94 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -22,6 +22,7 @@ import numpy as np import pytest import torch +from parameterized import parameterized from diffusers import ( AutoencoderKL, @@ -29,10 +30,12 @@ LCMScheduler, UNet2DConditionModel, ) +from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY from diffusers.utils import logging from diffusers.utils.import_utils import is_peft_available from diffusers.utils.testing_utils import ( CaptureLogger, + check_if_dicts_are_equal, floats_tensor, is_torch_version, require_peft_backend, @@ -71,6 +74,13 @@ def check_if_lora_correctly_set(model) -> bool: return False +def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str): + extracted = { + k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.") + } + check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"]) + + def initialize_dummy_state_dict(state_dict): if not all(v.device.type == "meta" for _, v in state_dict.items()): raise ValueError("`state_dict` has non-meta values.") @@ -106,7 +116,7 @@ class PeftLoraLoaderMixinTests: text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - def get_dummy_components(self, scheduler_cls=None, use_dora=False): + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") if self.has_two_text_encoders and self.has_three_text_encoders: @@ -114,6 +124,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls rank = 4 + lora_alpha = rank if lora_alpha is None else lora_alpha torch.manual_seed(0) if self.unet_kwargs is not None: @@ -149,7 +160,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): text_lora_config = LoraConfig( r=rank, - lora_alpha=rank, + lora_alpha=lora_alpha, target_modules=self.text_encoder_target_modules, init_lora_weights=False, use_dora=use_dora, @@ -157,7 +168,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): denoiser_lora_config = LoraConfig( r=rank, - lora_alpha=rank, + lora_alpha=lora_alpha, target_modules=self.denoiser_target_modules, init_lora_weights=False, use_dora=use_dora, @@ -234,6 +245,13 @@ def _get_lora_state_dicts(self, modules_to_save): state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) return state_dicts + def _get_lora_adapter_metadata(self, modules_to_save): + metadatas = {} + for module_name, module in modules_to_save.items(): + if module is not None: + metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict() + return metadatas + def _get_modules_to_save(self, pipe, has_denoiser=False): modules_to_save = {} lora_loadable_modules = self.pipeline_class._lora_loadable_modules @@ -2149,3 +2167,83 @@ def check_module(denoiser): _, _, inputs = self.get_dummy_inputs(with_generator=False) pipe(**inputs, generator=torch.manual_seed(0))[0] + + @parameterized.expand([4, 8, 16]) + def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = 
self.get_dummy_components( + scheduler_cls, lora_alpha=lora_alpha + ) + pipe = self.pipeline_class(**components) + + pipe, _ = self.check_if_adapters_added_correctly( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + + with tempfile.TemporaryDirectory() as tmpdir: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + + out = pipe.lora_state_dict(tmpdir) + if isinstance(out, tuple): + state_dict, _ = out + else: + state_dict = out + + self.assertTrue(LORA_ADAPTER_METADATA_KEY in state_dict) + + parsed_metadata = state_dict[LORA_ADAPTER_METADATA_KEY] + denoiser_key = ( + f"{self.pipeline_class.transformer_name}" + if self.transformer_kwargs is not None + else f"{self.pipeline_class.unet_name}" + ) + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key + ) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + text_encoder_key = self.pipeline_class.text_encoder_name + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key + ) + + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + text_encoder_2_key = "text_encoder_2" + check_module_lora_metadata( + parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key + ) + + @parameterized.expand([4, 8, 16]) + def test_lora_adapter_metadata_save_load_inference(self, lora_alpha): + scheduler_cls = self.scheduler_classes[0] + components, text_lora_config, denoiser_lora_config = self.get_dummy_components( + scheduler_cls, lora_alpha=lora_alpha + ) + pipe = self.pipeline_class(**components).to(torch_device) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(output_no_lora.shape == self.output_shape) + + pipe, _ = self.check_if_adapters_added_correctly( + pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config + ) + output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdir: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + lora_metadatas = self._get_lora_adapter_metadata(modules_to_save) + self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) + pipe.unload_lora_weights() + pipe.load_lora_weights(tmpdir) + + output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue( + np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match." 
+ ) From 116306edda32459a1cb91ace9b1a6b013139167c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 May 2025 14:32:39 +0530 Subject: [PATCH 22/53] fix-copies --- src/diffusers/loaders/lora_pipeline.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index eb56f4365253..eb43db1fcf73 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -737,7 +737,6 @@ def lora_state_dict( """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -1355,6 +1354,9 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, + text_encoder_2_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -1380,8 +1382,12 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: TODO + text_encoder_lora_adapter_metadata: TODO + text_encoder_2_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( @@ -1397,6 +1403,15 @@ def save_lora_weights( if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name)) + + if text_encoder_2_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2")) + cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, @@ -1404,6 +1419,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer From ae0580a548cdc7d8c8eeed26cbc84041a05be501 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 May 2025 14:43:01 +0530 Subject: [PATCH 23/53] sd --- src/diffusers/loaders/lora_pipeline.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index eb43db1fcf73..4b814d43e823 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -457,6 +457,8 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + unet_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -479,8 +481,11 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. 
safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + unet_lora_adapter_metadata: TODO + text_encoder_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not (unet_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") @@ -491,6 +496,12 @@ def save_lora_weights( if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + if unet_lora_adapter_metadata is not None: + lora_adapter_metadata.update(cls.pack_weights(unet_lora_adapter_metadata, cls.unet_name)) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -499,6 +510,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( From f6fde6f4a81fad2cf23c5042febb4960d83208d5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 May 2025 14:46:43 +0530 Subject: [PATCH 24/53] docs. --- src/diffusers/loaders/lora_pipeline.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 4b814d43e823..c85916d1709b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2267,6 +2267,8 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. @@ -2289,8 +2291,11 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
+ transformer_lora_adapter_metadata: TODO + text_encoder_lora_adapter_metadata: TODO """ state_dict = {} + lora_adapter_metadata = {} if not (transformer_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") @@ -2301,6 +2306,12 @@ def save_lora_weights( if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + + if text_encoder_lora_adapter_metadata: + lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name)) + # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -2309,6 +2320,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( From 87417b21545c6a11e6faa9eef634fc0d1ccaa85d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 May 2025 15:40:28 +0530 Subject: [PATCH 25/53] fixes --- src/diffusers/loaders/lora_base.py | 4 +++- src/diffusers/loaders/peft.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 016a387a6575..66272009f507 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -353,11 +353,13 @@ def _load_lora_into_text_encoder( raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.") # Load the layers corresponding to text encoder and make necessary adjustments. + metadata = None if LORA_ADAPTER_METADATA_KEY in state_dict: metadata = state_dict[LORA_ADAPTER_METADATA_KEY] if prefix is not None: state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - state_dict[LORA_ADAPTER_METADATA_KEY] = metadata + if metadata is not None: + state_dict[LORA_ADAPTER_METADATA_KEY] = metadata if len(state_dict) > 0: logger.info(f"Loading {prefix}.") diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 51f0b432bb79..69fdabdc18a6 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -234,6 +234,7 @@ def load_lora_adapter( user_agent=user_agent, allow_pickle=allow_pickle, ) + metadata = None if LORA_ADAPTER_METADATA_KEY in state_dict: metadata = state_dict[LORA_ADAPTER_METADATA_KEY] if network_alphas is not None and prefix is None: @@ -241,7 +242,8 @@ def load_lora_adapter( if prefix is not None: state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - state_dict[LORA_ADAPTER_METADATA_KEY] = metadata + if metadata is not None: + state_dict[LORA_ADAPTER_METADATA_KEY] = metadata if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: From 16dba2dffe7fdec98cd16ece770ecc9521326bf4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 9 May 2025 09:42:24 +0530 Subject: [PATCH 26/53] get review ready. 
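The `try`/`except` introduced below exists because `LoraConfig` is a dataclass: any key in the stored metadata that the installed `peft` version does not recognize makes instantiation fail with a `TypeError`. A minimal sketch of that failure mode (the extra key is made up for illustration):

```py
from peft import LoraConfig

# "some_future_field" stands in for any key the installed peft release does not accept.
lora_config_kwargs = {"r": 4, "lora_alpha": 4, "some_future_field": True}
try:
    LoraConfig(**lora_config_kwargs)
except TypeError as e:
    print(f"`LoraConfig` class could not be instantiated:\n{e}.")
```

At this stage the error is only logged; the follow-up commit turns it into a hard failure.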
--- src/diffusers/loaders/lora_base.py | 5 ++- src/diffusers/loaders/lora_pipeline.py | 60 +++++++++++++++++--------- src/diffusers/loaders/peft.py | 8 +--- 3 files changed, 46 insertions(+), 27 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 66272009f507..c674708c058e 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -409,7 +409,10 @@ def _load_lora_into_text_encoder( if is_peft_version("<=", "0.13.2"): lora_config_kwargs.pop("lora_bias") - lora_config = LoraConfig(**lora_config_kwargs) + try: + lora_config = LoraConfig(**lora_config_kwargs) + except TypeError as e: + logger.error(f"`LoraConfig` class could not be instantiated:\n{e}.") # adapter_name if adapter_name is None: diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index ca4a17169161..d567baccb8e7 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -481,8 +481,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - unet_lora_adapter_metadata: TODO - text_encoder_lora_adapter_metadata: TODO + unet_lora_adapter_metadata: + LoRA adapter metadata associated with the unet to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -959,9 +961,12 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - unet_lora_adapter_metadata: TODO - text_encoder_lora_adapter_metadata: TODO - text_encoder_2_lora_adapter_metadata: TODO + unet_lora_adapter_metadata: + LoRA adapter metadata associated with the unet to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. + text_encoder_2_lora_adapter_metadata: + LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -1394,9 +1399,12 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO - text_encoder_lora_adapter_metadata: TODO - text_encoder_2_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. + text_encoder_2_lora_adapter_metadata: + LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -1738,7 +1746,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. 
- transformer_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -2291,8 +2300,10 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO - text_encoder_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -3074,7 +3085,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -3397,7 +3409,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -3722,7 +3735,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -4047,7 +4061,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -4375,7 +4390,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -4704,7 +4720,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -5062,7 +5079,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. 
safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -5387,7 +5405,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} @@ -5712,7 +5731,8 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - transformer_lora_adapter_metadata: TODO + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} lora_adapter_metadata = {} diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 69fdabdc18a6..9ddd9fc506a6 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -116,11 +116,7 @@ def _optionally_disable_offloading(cls, _pipeline): return _func_optionally_disable_offloading(_pipeline=_pipeline) def load_lora_adapter( - self, - pretrained_model_name_or_path_or_dict, - prefix="transformer", - hotswap: bool = False, - **kwargs, + self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs ): r""" Loads a LoRA adapter into the underlying model. @@ -309,7 +305,7 @@ def load_lora_adapter( try: lora_config = LoraConfig(**lora_config_kwargs) except TypeError as e: - logger.error(f"`LoraConfig` class could not be instantiated with the following trace: {e}.") + logger.error(f"`LoraConfig` class could not be instantiated:\n{e}.") # adapter_name if adapter_name is None: From 67bceda5b29b58fda4376ac3deff6228d03c1f3f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 9 May 2025 18:23:21 +0530 Subject: [PATCH 27/53] one more test to catch error. 
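The new test simulates an incompatible checkpoint by rewriting the safetensors header by hand. A standalone sketch of that manipulation, using dummy tensors and the literal key this series stores the config under:

```py
import json

import safetensors.torch
import torch

# safetensors metadata is a flat str -> str mapping, so the LoRA config travels JSON-encoded.
tensors = {"dummy.lora_A.weight": torch.zeros(4, 8)}
bad_config = {"r": 4, "lora_alpha": 4, "foo": 1, "bar": 2}  # "foo"/"bar" are not LoraConfig fields
metadata = {"format": "pt", "lora_adapter_metadata": json.dumps(bad_config, indent=2, sort_keys=True)}
safetensors.torch.save_file(tensors, "pytorch_lora_weights.safetensors", metadata=metadata)
```

Loading such a file should now surface the `TypeError` raised from `LoraConfig(**lora_config_kwargs)` instead of only logging it.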
--- src/diffusers/loaders/lora_base.py | 2 +- src/diffusers/loaders/peft.py | 2 +- tests/models/test_modeling_common.py | 47 ++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index c674708c058e..c6babb283e99 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -412,7 +412,7 @@ def _load_lora_into_text_encoder( try: lora_config = LoraConfig(**lora_config_kwargs) except TypeError as e: - logger.error(f"`LoraConfig` class could not be instantiated:\n{e}.") + raise TypeError(f"`LoraConfig` class could not be instantiated:\n{e}.") # adapter_name if adapter_name is None: diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 9ddd9fc506a6..fa7e867d87d2 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -305,7 +305,7 @@ def load_lora_adapter( try: lora_config = LoraConfig(**lora_config_kwargs) except TypeError as e: - logger.error(f"`LoraConfig` class could not be instantiated:\n{e}.") + raise TypeError(f"`LoraConfig` class could not be instantiated:\n{e}.") # adapter_name if adapter_name is None: diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 49342f5affa5..5e6edc2654c3 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1184,6 +1184,53 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_d parsed_metadata = model.peft_config["default_0"].to_dict() check_if_dicts_are_equal(metadata, parsed_metadata) + @torch.no_grad() + @unittest.skipIf(not is_peft_available(), "Only with PEFT") + def test_lora_adapter_wrong_metadata_raises_error(self): + from peft import LoraConfig + + from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + from diffusers.loaders.peft import PeftAdapterMixin + + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + if not issubclass(model.__class__, PeftAdapterMixin): + return + + denoiser_lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + model.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + + with tempfile.TemporaryDirectory() as tmpdir: + model.save_lora_adapter(tmpdir) + model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(model_file)) + + # Perturb the metadata in the state dict. 
+ loaded_state_dict = safetensors.torch.load_file(model_file) + metadata = {"format": "pt"} + lora_adapter_metadata = denoiser_lora_config.to_dict() + lora_adapter_metadata.update({"foo": 1, "bar": 2}) + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata) + + model.unload_lora() + self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly") + + with self.assertRaises(TypeError) as err_context: + model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) + self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception)) + @require_torch_accelerator def test_cpu_offload(self): config, inputs_dict = self.prepare_init_args_and_inputs_for_common() From 4304a6d91d662e23d840bc86b0c375be8c71d5ad Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 10:06:26 +0530 Subject: [PATCH 28/53] change to a different approach. --- src/diffusers/loaders/lora_base.py | 37 +++++++++++++++++-------- src/diffusers/loaders/lora_pipeline.py | 26 +++++++++++++---- src/diffusers/loaders/peft.py | 37 ++++++++++++------------- src/diffusers/utils/state_dict_utils.py | 21 ++++++-------- tests/lora/utils.py | 17 ++++++------ 5 files changed, 79 insertions(+), 59 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 6d507eaf34c5..a326d054ac67 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -46,7 +46,7 @@ set_adapter_layers, set_weights_and_activate_adapters, ) -from ..utils.state_dict_utils import _maybe_populate_state_dict_with_metadata +from ..utils.state_dict_utils import _load_sft_state_dict_metadata if is_transformers_available(): @@ -209,6 +209,7 @@ def _fetch_state_dict( subfolder, user_agent, allow_pickle, + metadata=None, ): model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): @@ -240,13 +241,14 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") - state_dict = _maybe_populate_state_dict_with_metadata(state_dict, model_file) + metadata = _load_sft_state_dict_metadata(model_file) except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e # try loading non-safetensors weights model_file = None + metadata = None pass if model_file is None: @@ -267,10 +269,11 @@ def _fetch_state_dict( user_agent=user_agent, ) state_dict = load_state_dict(model_file) + metadata = None else: state_dict = pretrained_model_name_or_path_or_dict - return state_dict + return state_dict, metadata def _best_guess_weight_name( @@ -312,6 +315,11 @@ def _best_guess_weight_name( return weight_name +def _pack_sd_with_prefix(state_dict, prefix): + sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()} + return sd_with_prefix + + def _load_lora_into_text_encoder( state_dict, network_alphas, @@ -320,6 +328,7 @@ def _load_lora_into_text_encoder( lora_scale=1.0, text_encoder_name="text_encoder", adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -327,6 +336,9 @@ def _load_lora_into_text_encoder( if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + if network_alphas and metadata: + raise ValueError("Both 
`network_alphas` and `metadata` cannot be specified.") + peft_kwargs = {} if low_cpu_mem_usage: if not is_peft_version(">=", "0.13.1"): @@ -353,13 +365,10 @@ def _load_lora_into_text_encoder( raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.") # Load the layers corresponding to text encoder and make necessary adjustments. - metadata = None - if LORA_ADAPTER_METADATA_KEY in state_dict: - metadata = state_dict[LORA_ADAPTER_METADATA_KEY] if prefix is not None: state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - if metadata is not None: - state_dict[LORA_ADAPTER_METADATA_KEY] = metadata + if metadata is not None: + metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: logger.info(f"Loading {prefix}.") @@ -387,7 +396,10 @@ def _load_lora_into_text_encoder( alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys} - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False, prefix=prefix) + if metadata is not None: + lora_config_kwargs = metadata + else: + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False, prefix=prefix) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: @@ -885,8 +897,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, @staticmethod def pack_weights(layers, prefix): layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict + return _pack_sd_with_prefix(layers_weights, prefix) @staticmethod def write_lora_layers( @@ -917,7 +928,9 @@ def save_function(weights, filename): for key, value in lora_adapter_metadata.items(): if isinstance(value, set): lora_adapter_metadata[key] = list(value) - metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps( + lora_adapter_metadata, indent=2, sort_keys=True + ) return safetensors.torch.save_file(weights, filename, metadata=metadata) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 698453949e41..8eaad820ad6f 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -37,6 +37,7 @@ LoraBaseMixin, _fetch_state_dict, _load_lora_into_text_encoder, + _pack_sd_with_prefix, ) from .lora_conversion_utils import ( _convert_bfl_flux_control_lora_to_diffusers, @@ -197,7 +198,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -208,6 +210,7 @@ def load_lora_weights( network_alphas=network_alphas, unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -221,6 +224,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + metadata=metadata, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) @@ -277,6 +281,7 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -290,6 +295,7 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: @@ -301,7 +307,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -338,7 +344,8 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - return state_dict, network_alphas + out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas) + return out @classmethod def load_lora_into_unet( @@ -347,6 +354,7 @@ def load_lora_into_unet( network_alphas, unet, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -391,6 +399,7 @@ def load_lora_into_unet( prefix=cls.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -405,6 +414,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -430,6 +440,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. 
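For reference, a sketch of the call pattern the new `return_lora_metadata` flag enables on this mixin; the repo id is a placeholder:

```py
from diffusers import StableDiffusionPipeline

# Default behaviour is unchanged.
state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict("user/sd-lora")

# Opting in adds the parsed adapter metadata as a third element.
state_dict, network_alphas, metadata = StableDiffusionPipeline.lora_state_dict(
    "user/sd-lora", return_lora_metadata=True
)
# `metadata` is None for non-safetensors files or files saved without adapter metadata.
```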
@@ -444,6 +455,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -500,11 +512,13 @@ def save_lora_weights( if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - if unet_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(unet_lora_adapter_metadata, cls.unet_name)) + if unet_lora_adapter_metadata: + lora_adapter_metadata.update(_pack_sd_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name)) + lora_adapter_metadata.update( + _pack_sd_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) # Save the model cls.write_lora_layers( diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index b0ff6c759dfb..b73813aa6d18 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -185,13 +185,11 @@ def load_lora_adapter( Note that hotswapping adapters of the text encoder is not yet supported. There are some further limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap - + metadata: TODO """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer - from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY - cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -205,6 +203,7 @@ def load_lora_adapter( network_alphas = kwargs.pop("network_alphas", None) _pipeline = kwargs.pop("_pipeline", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) + metadata = kwargs.pop("metadata", None) allow_pickle = False if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): @@ -212,12 +211,9 @@ def load_lora_adapter( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -230,17 +226,17 @@ def load_lora_adapter( subfolder=subfolder, user_agent=user_agent, allow_pickle=allow_pickle, + metadata=metadata, ) - metadata = None - if LORA_ADAPTER_METADATA_KEY in state_dict: - metadata = state_dict[LORA_ADAPTER_METADATA_KEY] if network_alphas is not None and prefix is None: raise ValueError("`network_alphas` cannot be None when `prefix` is None.") + if network_alphas and metadata: + raise ValueError("Both `network_alphas` and `metadata` cannot be specified.") if prefix is not None: state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - if metadata is not None: - state_dict[LORA_ADAPTER_METADATA_KEY] = metadata + if metadata is not None: + metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}) and not hotswap: @@ -275,12 +271,15 @@ def load_lora_adapter( k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys } - lora_config_kwargs = get_peft_kwargs( - rank, - network_alpha_dict=network_alphas, - peft_state_dict=state_dict, - prefix=prefix, - ) + if metadata is not None: + lora_config_kwargs = metadata + else: + lora_config_kwargs = get_peft_kwargs( + rank, + network_alpha_dict=network_alphas, + peft_state_dict=state_dict, + prefix=prefix, + ) _maybe_raise_error_for_ambiguity(lora_config_kwargs) if "use_dora" in lora_config_kwargs: diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 0dbbce5713b7..9e0c208eea92 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -350,21 +350,16 @@ def state_dict_all_zero(state_dict, filter_str=None): return all(torch.all(param == 0).item() for param in state_dict.values()) -def _maybe_populate_state_dict_with_metadata(state_dict, model_file): - if not model_file.endswith(".safetensors"): - return state_dict - +def _load_sft_state_dict_metadata(model_file: str): import safetensors.torch from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY - metadata_key = LORA_ADAPTER_METADATA_KEY + metadata = None with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f: - if hasattr(f, "metadata"): - metadata = f.metadata() - if metadata is not None: - metadata_keys = list(metadata.keys()) - if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"): - peft_metadata = {k: v for k, v in metadata.items() if k != "format"} - state_dict["lora_adapter_metadata"] = json.loads(peft_metadata[metadata_key]) - return state_dict + metadata = f.metadata() + if metadata is not None: + metadata_keys = list(metadata.keys()) + if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"): + metadata = json.loads(metadata[LORA_ADAPTER_METADATA_KEY]) + return metadata diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 3fe83f102f94..82fe8f071e2a 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -30,7 +30,6 @@ LCMScheduler, UNet2DConditionModel, ) -from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY from diffusers.utils 
import logging from diffusers.utils.import_utils import is_peft_available from diffusers.utils.testing_utils import ( @@ -2187,32 +2186,32 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha): self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas) pipe.unload_lora_weights() - out = pipe.lora_state_dict(tmpdir) - if isinstance(out, tuple): - state_dict, _ = out - else: - state_dict = out - - self.assertTrue(LORA_ADAPTER_METADATA_KEY in state_dict) + out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True) + if len(out) == 3: + _, _, parsed_metadata = out + elif len(out) == 2: + _, parsed_metadata = out - parsed_metadata = state_dict[LORA_ADAPTER_METADATA_KEY] denoiser_key = ( f"{self.pipeline_class.transformer_name}" if self.transformer_kwargs is not None else f"{self.pipeline_class.unet_name}" ) + self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata)) check_module_lora_metadata( parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key ) if "text_encoder" in self.pipeline_class._lora_loadable_modules: text_encoder_key = self.pipeline_class.text_encoder_name + self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata)) check_module_lora_metadata( parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key ) if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: text_encoder_2_key = "text_encoder_2" + self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata)) check_module_lora_metadata( parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key ) From 425ea95fb264c4177940b35665e97a085c1f3d22 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 10:07:13 +0530 Subject: [PATCH 29/53] fix-copies. --- src/diffusers/loaders/lora_pipeline.py | 29 +++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 8eaad820ad6f..6ea7457ddea5 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -764,6 +764,7 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. 
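The updated tests above rely on the metadata dict being namespaced per module, exactly like the weights themselves. A toy illustration with a local stand-in for `_pack_sd_with_prefix`:

```py
# Local re-implementation for illustration only; the real helper lives in lora_base.py.
def pack_with_prefix(d, prefix):
    return {f"{prefix}.{k}": v for k, v in d.items()}

unet_metadata = {"r": 4, "lora_alpha": 8, "target_modules": ["to_q", "to_k"]}
packed = pack_with_prefix(unet_metadata, "unet")

assert any(k.startswith("unet.") for k in packed)  # the same shape check the updated test performs
print(packed)  # {'unet.r': 4, 'unet.lora_alpha': 8, 'unet.target_modules': ['to_q', 'to_k']}
```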
@@ -777,6 +778,7 @@ def lora_state_dict( weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: @@ -788,7 +790,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -825,7 +827,8 @@ def lora_state_dict( state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - return state_dict, network_alphas + out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas) + return out @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet @@ -835,6 +838,7 @@ def load_lora_into_unet( network_alphas, unet, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -879,6 +883,7 @@ def load_lora_into_unet( prefix=cls.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -894,6 +899,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -919,6 +925,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -933,6 +940,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1331,6 +1339,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -1356,6 +1365,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -1370,6 +1380,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2237,6 +2248,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -2262,6 +2274,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
+ metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -2276,6 +2289,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2333,11 +2347,13 @@ def save_lora_weights( if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + if transformer_lora_adapter_metadata: + lora_adapter_metadata.update(_pack_sd_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)) if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name)) + lora_adapter_metadata.update( + _pack_sd_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) # Save the model cls.write_lora_layers( @@ -2769,6 +2785,7 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -2794,6 +2811,7 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -2808,6 +2826,7 @@ def load_lora_into_text_encoder( prefix=prefix, text_encoder_name=cls.text_encoder_name, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, From e08830ef87744074303185c127b628c2119bce6d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 10:08:06 +0530 Subject: [PATCH 30/53] todo --- src/diffusers/loaders/lora_pipeline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 6ea7457ddea5..a60aa983dffc 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -376,6 +376,7 @@ def load_lora_into_unet( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. @@ -860,6 +861,7 @@ def load_lora_into_unet( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. 
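Taken together, the pieces so far round-trip as follows. A sketch of reading the serialized adapter config back out of the safetensors header; the path is a placeholder and assumes the file was written through the new `*_lora_adapter_metadata` arguments of `save_lora_weights`:

```py
import json

import safetensors.torch

path = "pytorch_lora_weights.safetensors"  # placeholder
with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
    raw = f.metadata() or {}

config_dict = json.loads(raw["lora_adapter_metadata"]) if "lora_adapter_metadata" in raw else None
print(config_dict)
# When present, this dict (after any module prefix is stripped) supplies the LoraConfig kwargs
# directly instead of re-deriving them from the state dict via get_peft_kwargs().
```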
From 40f5c974e0b7930c6553669bcb55f25126cd64ee Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 10:14:46 +0530 Subject: [PATCH 31/53] sd3 --- src/diffusers/loaders/lora_pipeline.py | 186 +++++++++++++++++++------ 1 file changed, 144 insertions(+), 42 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a60aa983dffc..a5fb0184cfa8 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -665,7 +665,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict( + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, **kwargs, @@ -675,19 +676,16 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - from .lora_base import LORA_ADAPTER_METADATA_KEY - - print(f"{LORA_ADAPTER_METADATA_KEY in state_dict=} before UNet") self.load_lora_into_unet( state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) - print(f"{LORA_ADAPTER_METADATA_KEY in state_dict=} before text encoder.") self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, @@ -695,11 +693,11 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, ) - print(f"{LORA_ADAPTER_METADATA_KEY in state_dict=} before text encoder 2.") self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, @@ -707,6 +705,7 @@ def load_lora_weights( prefix=f"{self.text_encoder_name}_2", lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1161,6 +1160,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
+ return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of @@ -1174,18 +1174,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1206,7 +1204,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out def load_lora_weights( self, @@ -1255,7 +1254,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -1265,6 +1265,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1276,6 +1277,7 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1287,6 +1289,7 @@ def load_lora_weights( prefix=f"{self.text_encoder_name}_2", lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1294,7 +1297,14 @@ def load_lora_weights( @classmethod def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1309,6 +1319,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. 
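Unlike the SD and SDXL mixins, the SD3-style `lora_state_dict` has no `network_alphas` to return, so the opt-in flag changes the return value from a bare dict to a 2-tuple; the repo id below is a placeholder:

```py
from diffusers import StableDiffusion3Pipeline

state_dict = StableDiffusion3Pipeline.lora_state_dict("user/sd3-lora")
state_dict, metadata = StableDiffusion3Pipeline.lora_state_dict("user/sd3-lora", return_lora_metadata=True)
```

This difference in shape is also why the updated pipeline test branches on `len(out)`.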
@@ -1326,6 +1337,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -1708,7 +1720,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1723,6 +1742,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -1740,6 +1760,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2948,6 +2969,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of @@ -2961,18 +2983,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -2993,7 +3013,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out def load_lora_weights( self, @@ -3055,7 +3076,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3070,6 +3098,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
+ metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -3087,6 +3116,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3271,6 +3301,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of @@ -3284,18 +3315,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3316,7 +3345,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -3379,7 +3409,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3394,6 +3431,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -3411,6 +3449,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3708,7 +3747,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. 
@@ -3723,6 +3769,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -3740,6 +3787,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3926,6 +3974,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of @@ -3939,18 +3988,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3971,7 +4018,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -4034,7 +4082,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4049,6 +4104,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. 
@@ -4066,6 +4122,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4363,7 +4420,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4378,6 +4442,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -4395,6 +4460,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4693,7 +4759,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4708,6 +4781,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -4725,6 +4799,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5073,7 +5148,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5088,6 +5170,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
+ metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -5105,6 +5188,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5399,7 +5483,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5414,6 +5505,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -5431,6 +5523,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5728,7 +5821,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + metadata=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5743,6 +5843,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -5760,6 +5861,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, From 5a2a02322017f6384fec0d1810a9ff0c9fa7645c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 10:39:31 +0530 Subject: [PATCH 32/53] update --- src/diffusers/loaders/lora_pipeline.py | 68 +++++++++++++++++--------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a5fb0184cfa8..9d4c0b8fc5bb 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1612,6 +1612,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
+ return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of @@ -1625,18 +1626,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1657,7 +1656,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -1702,7 +1702,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -1712,6 +1713,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3058,7 +3060,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3068,6 +3071,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3391,7 +3395,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3401,6 +3406,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3635,7 +3641,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -3648,6 +3654,7 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: @@ -3659,7 +3666,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -3684,7 +3691,8 @@ def lora_state_dict( if is_non_diffusers_format: state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict) - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -3729,7 +3737,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -3739,6 +3748,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4064,7 +4074,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4074,6 +4085,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4402,7 +4414,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4412,6 +4425,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -4741,7 +4755,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -4751,6 +4766,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5375,6 +5391,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
+ return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of @@ -5388,18 +5405,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5420,7 +5435,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -5465,7 +5481,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -5475,6 +5492,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5803,7 +5821,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: @@ -5813,6 +5832,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, From 0ae34081884034ee0a890d9b654c3c37cabc0200 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 10:47:20 +0530 Subject: [PATCH 33/53] revert changes in get_peft_kwargs. 
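With this revert, `get_peft_kwargs` goes back to only deriving the rank/alpha patterns from the state dict and no longer pulls adapter metadata out of `peft_state_dict`. The parsed metadata is carried separately and handed to the loader explicitly; the caller then prefers it over re-derivation, roughly like the `peft.py` hunk later in this series (a sketch, not a verbatim copy of the final code):

    # Inside load_lora_adapter(): use explicitly supplied metadata when given,
    # otherwise fall back to deriving the LoraConfig kwargs from the weights.
    if metadata is not None:
        lora_config_kwargs = metadata
    else:
        lora_config_kwargs = get_peft_kwargs(
            rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
        )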
--- src/diffusers/utils/peft_utils.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 9a95d0f44db0..66cf7ef8a5c5 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -152,19 +152,7 @@ def get_peft_kwargs( network_alpha_dict, peft_state_dict, is_unet=True, - prefix=None, ): - from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY - - if LORA_ADAPTER_METADATA_KEY in peft_state_dict: - metadata = peft_state_dict[LORA_ADAPTER_METADATA_KEY] - else: - metadata = None - if metadata: - if prefix is not None: - metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} - return metadata - rank_pattern = {} alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] From 99fe09cdb54df6b8a5264a375d19fe13dc95c544 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 10:52:07 +0530 Subject: [PATCH 34/53] update --- src/diffusers/loaders/lora_base.py | 1 - src/diffusers/utils/peft_utils.py | 7 +------ tests/lora/test_lora_layers_wan.py | 4 ++-- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index a326d054ac67..910b077ed346 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -227,7 +227,6 @@ def _fetch_state_dict( file_extension=".safetensors", local_files_only=local_files_only, ) - model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 66cf7ef8a5c5..7d0a6faa7afb 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -147,12 +147,7 @@ def unscale_lora_layers(model, weight: Optional[float] = None): module.set_scale(adapter_name, 1.0) -def get_peft_kwargs( - rank_dict, - network_alpha_dict, - peft_state_dict, - is_unet=True, -): +def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): rank_pattern = {} alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 7866556e900f..a1420012d601 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -24,7 +24,7 @@ WanPipeline, WanTransformer3DModel, ) -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps sys.path.append(".") @@ -33,7 +33,7 @@ @require_peft_backend -# @skip_mps +@skip_mps class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = WanPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler From f4d417975e461fe6a0294733db55da8ed59404ef Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 11:26:17 +0530 Subject: [PATCH 35/53] fixes --- examples/community/ip_adapter_face_id.py | 5 +- src/diffusers/loaders/ip_adapter.py | 15 ++---- src/diffusers/loaders/lora_base.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 66 +++++++++--------------- src/diffusers/loaders/peft.py | 5 +- src/diffusers/loaders/unet.py | 5 +- 6 files changed, 32 insertions(+), 66 deletions(-) diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py index 203be1d4c874..7a7f1bcbd31a 100644 --- a/examples/community/ip_adapter_face_id.py +++ 
b/examples/community/ip_adapter_face_id.py @@ -282,10 +282,7 @@ def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_ revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name, diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index f4c48f254c44..63b4f226434d 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -159,10 +159,7 @@ def load_ip_adapter( " `low_cpu_mem_usage=False`." ) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dicts = [] for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( pretrained_model_name_or_path_or_dict, weight_name, subfolder @@ -465,10 +462,7 @@ def load_ip_adapter( " `low_cpu_mem_usage=False`." ) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dicts = [] for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( pretrained_model_name_or_path_or_dict, weight_name, subfolder @@ -750,10 +744,7 @@ def load_ip_adapter( " `low_cpu_mem_usage=False`." ) - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} if not isinstance(pretrained_model_name_or_path_or_dict, dict): model_file = _get_model_file( diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 910b077ed346..92347ccf2425 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -398,7 +398,7 @@ def _load_lora_into_text_encoder( if metadata is not None: lora_config_kwargs = metadata else: - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False, prefix=prefix) + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9d4c0b8fc5bb..9723bcabbecf 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -302,10 +302,7 @@ def lora_state_dict( use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -785,10 +782,7 @@ def lora_state_dict( use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -1954,7 +1948,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
- + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -1967,18 +1961,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -2032,10 +2024,12 @@ def lora_state_dict( f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." ) + outputs = [state_dict] if return_alphas: - return state_dict, network_alphas - else: - return state_dict + outputs.append(network_alphas) + if return_lora_metadata: + outputs.append(metadata) + return tuple(outputs) def load_lora_weights( self, @@ -2084,7 +2078,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict( + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) @@ -2135,6 +2130,7 @@ def load_lora_weights( network_alphas=network_alphas, transformer=transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -2154,6 +2150,7 @@ def load_lora_weights( prefix=self.text_encoder_name, lora_scale=self.lora_scale, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -3661,10 +3658,7 @@ def lora_state_dict( use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -4339,10 +4333,7 @@ def lora_state_dict( use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -4679,10 +4670,7 @@ def lora_state_dict( use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, @@ -5001,7 +4989,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. 
- + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -5014,18 +5002,16 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5050,7 +5036,8 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out @classmethod def _maybe_expand_t2v_lora_for_i2v( @@ -5746,10 +5733,7 @@ def lora_state_dict( use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index b73813aa6d18..6968ce3b7421 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -275,10 +275,7 @@ def load_lora_adapter( lora_config_kwargs = metadata else: lora_config_kwargs = get_peft_kwargs( - rank, - network_alpha_dict=network_alphas, - peft_state_dict=state_dict, - prefix=prefix, + rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict ) _maybe_raise_error_for_ambiguity(lora_config_kwargs) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 1d8aba900c85..d9308c57bfe5 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -155,10 +155,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict use_safetensors = True allow_pickle = True - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): From 46f47263cf0967efb53f8c326c188dd045de7571 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 11:30:44 +0530 Subject: [PATCH 36/53] fixes --- src/diffusers/loaders/lora_pipeline.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9723bcabbecf..143d9edda0b0 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4314,7 +4314,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. 
@@ -4327,6 +4327,7 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: @@ -4335,7 +4336,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4360,7 +4361,8 @@ def lora_state_dict( if is_original_hunyuan_video: state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -4651,7 +4653,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -4664,6 +4666,7 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: @@ -4672,7 +4675,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -4698,7 +4701,8 @@ def lora_state_dict( if non_diffusers: state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( @@ -5714,7 +5718,7 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + return_lora_metadata: TODO """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. 
@@ -5727,6 +5731,7 @@ def lora_state_dict( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) allow_pickle = False if use_safetensors is None: @@ -5735,7 +5740,7 @@ def lora_state_dict( user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - state_dict = _fetch_state_dict( + state_dict, metadata = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -5760,7 +5765,8 @@ def lora_state_dict( if is_non_diffusers_format: state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict) - return state_dict + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( From 1348463290f99dfffa7a2408cda525bddc33e600 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 11:51:32 +0530 Subject: [PATCH 37/53] simplify _load_sft_state_dict_metadata --- src/diffusers/utils/state_dict_utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 9e0c208eea92..498f7e566c67 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -355,11 +355,9 @@ def _load_sft_state_dict_metadata(model_file: str): from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY - metadata = None with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f: - metadata = f.metadata() - if metadata is not None: - metadata_keys = list(metadata.keys()) - if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"): - metadata = json.loads(metadata[LORA_ADAPTER_METADATA_KEY]) - return metadata + metadata = f.metadata() or {} + + metadata.pop("format", None) + raw = metadata.get(LORA_ADAPTER_METADATA_KEY) + return json.loads(raw) if raw else None From ef16bcec6f9b3856877697444fb5f03ca058f559 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 11:53:08 +0530 Subject: [PATCH 38/53] update --- src/diffusers/loaders/lora_pipeline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 143d9edda0b0..b7192068be27 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2163,6 +2163,7 @@ def load_lora_into_transformer( network_alphas, transformer, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, @@ -2184,6 +2185,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. 
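For reference alongside PATCH 37 above: the reworked `_load_sft_state_dict_metadata` helper reads the LoRA adapter metadata straight from the safetensors header. A self-contained sketch of that behaviour (the file name is illustrative; the constant is imported from the module the diff references):

    import json

    import safetensors.torch

    from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

    def read_lora_adapter_metadata(model_file: str):
        # The safetensors header stores string-to-string metadata next to the
        # tensors; the adapter config is kept there as a JSON string under a
        # dedicated key.
        with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
            metadata = f.metadata() or {}
        metadata.pop("format", None)  # written by default when saving, not LoRA metadata
        raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
        return json.loads(raw) if raw else None

    # e.g. read_lora_adapter_metadata("pytorch_lora_weights.safetensors")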
From 9cba78ea071d43228a10c204740b899cfbc84376 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 11:59:18 +0530 Subject: [PATCH 39/53] style fix --- src/diffusers/loaders/lora_pipeline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b7192068be27..fe3e0b63092b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2756,6 +2756,7 @@ def load_lora_into_transformer( network_alphas, transformer, adapter_name=None, + metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -2777,6 +2778,7 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + metadata: TODO low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. From 28d634f67152cdb3f3467ddd3af3d3d8397270db Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 12:23:19 +0530 Subject: [PATCH 40/53] update --- src/diffusers/loaders/lora_pipeline.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index fe3e0b63092b..d66f69b7606d 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2082,6 +2082,7 @@ def load_lora_weights( state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) + print(f"{metadata=}") has_lora_keys = any("lora" in key for key in state_dict.keys()) @@ -2203,6 +2204,7 @@ def load_lora_into_transformer( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5137,7 +5139,8 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers state_dict = self._maybe_expand_t2v_lora_for_i2v( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, @@ -5151,6 +5154,7 @@ def load_lora_weights( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, + metadata=metadata, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, From e07ace06537261bf517569292517f78130a2f901 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 12:25:13 +0530 Subject: [PATCH 41/53] update --- src/diffusers/loaders/lora_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index d66f69b7606d..a2ae326de2f7 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2798,6 +2798,7 @@ def load_lora_into_transformer( state_dict, network_alphas=network_alphas, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, From c762b7ca2065a97961a524f07b7fe4cb89cd8255 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 12:42:00 +0530 Subject: [PATCH 42/53] update --- src/diffusers/loaders/lora_pipeline.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a2ae326de2f7..e19ea492d13c 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2024,12 +2024,15 @@ def lora_state_dict( f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." 
) - outputs = [state_dict] - if return_alphas: - outputs.append(network_alphas) - if return_lora_metadata: - outputs.append(metadata) - return tuple(outputs) + if return_alphas or return_lora_metadata: + outputs = [state_dict] + if return_alphas: + outputs.append(network_alphas) + if return_lora_metadata: + outputs.append(metadata) + return tuple(outputs) + else: + return state_dict def load_lora_weights( self, @@ -2082,7 +2085,6 @@ def load_lora_weights( state_dict, network_alphas, metadata = self.lora_state_dict( pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs ) - print(f"{metadata=}") has_lora_keys = any("lora" in key for key in state_dict.keys()) From c8c33d32e1227191f4e2a2a3df4676cb36429d97 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 May 2025 13:29:01 +0530 Subject: [PATCH 43/53] empty commit From d952267fc6a9420ea8ee5a38c1d4790244fe4bfb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 16:11:31 +0530 Subject: [PATCH 44/53] _pack_dict_with_prefix --- src/diffusers/loaders/lora_base.py | 4 ++-- src/diffusers/loaders/lora_pipeline.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 80e63ece4cdc..8b7fc34945a7 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -314,7 +314,7 @@ def _best_guess_weight_name( return weight_name -def _pack_sd_with_prefix(state_dict, prefix): +def _pack_dict_with_prefix(state_dict, prefix): sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()} return sd_with_prefix @@ -914,7 +914,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, @staticmethod def pack_weights(layers, prefix): layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - return _pack_sd_with_prefix(layers_weights, prefix) + return _pack_dict_with_prefix(layers_weights, prefix) @staticmethod def write_lora_layers( diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 27f03ae54853..1eb2425da587 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -37,7 +37,7 @@ LoraBaseMixin, _fetch_state_dict, _load_lora_into_text_encoder, - _pack_sd_with_prefix, + _pack_dict_with_prefix, ) from .lora_conversion_utils import ( _convert_bfl_flux_control_lora_to_diffusers, @@ -511,11 +511,11 @@ def save_lora_weights( state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) if unet_lora_adapter_metadata: - lora_adapter_metadata.update(_pack_sd_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) + lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) if text_encoder_lora_adapter_metadata: lora_adapter_metadata.update( - _pack_sd_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) ) # Save the model @@ -2376,11 +2376,11 @@ def save_lora_weights( state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) if transformer_lora_adapter_metadata: - lora_adapter_metadata.update(_pack_sd_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update(_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)) if text_encoder_lora_adapter_metadata: lora_adapter_metadata.update( - 
_pack_sd_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) ) # Save the model From 9bbc6dc4219df6725a37ef31fa1a70bce5214157 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 16:13:06 +0530 Subject: [PATCH 45/53] update --- src/diffusers/loaders/lora_pipeline.py | 66 +++++++++++++++++++------- 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1eb2425da587..4b08005189bc 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1005,13 +1005,17 @@ def save_lora_weights( state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) if unet_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(unet_lora_adapter_metadata, cls.unet_name)) + lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) if text_encoder_2_lora_adapter_metadata: - lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2")) + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") + ) cls.write_lora_layers( state_dict=state_dict, @@ -1459,13 +1463,19 @@ def save_lora_weights( state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + ) if text_encoder_2_lora_adapter_metadata: - lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2")) + lora_adapter_metadata.update( + _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") + ) cls.write_lora_layers( state_dict=state_dict, @@ -1804,7 +1814,9 @@ def save_lora_weights( state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -2376,7 +2388,9 @@ def save_lora_weights( state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) if transformer_lora_adapter_metadata: - lora_adapter_metadata.update(_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) if text_encoder_lora_adapter_metadata: lora_adapter_metadata.update( @@ -3173,7 +3187,9 @@ def 
save_lora_weights( state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -3508,7 +3524,9 @@ def save_lora_weights( state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -3847,7 +3865,9 @@ def save_lora_weights( state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -4184,7 +4204,9 @@ def save_lora_weights( state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -4523,7 +4545,9 @@ def save_lora_weights( state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -4863,7 +4887,9 @@ def save_lora_weights( state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -5253,7 +5279,9 @@ def save_lora_weights( state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -5590,7 +5618,9 @@ def save_lora_weights( state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -5929,7 +5959,9 @@ def save_lora_weights( 
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name)) + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( From eb524696a129ac3f72f26c7fe7e9cc4230fffc2c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 16:15:37 +0530 Subject: [PATCH 46/53] TODO 1. --- src/diffusers/loaders/lora_pipeline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 4b08005189bc..34443f23ffec 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -281,7 +281,8 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -761,7 +762,8 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. From 461d2bd48e5fee0090e899899562e8b0c9ac834a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 16:17:38 +0530 Subject: [PATCH 47/53] todo: 2. --- src/diffusers/loaders/lora_pipeline.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 34443f23ffec..ec1339a79bf6 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -374,7 +374,9 @@ def load_lora_into_unet( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. @@ -856,7 +858,9 @@ def load_lora_into_unet( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. 
From f78c6f9bae1121d45222e20afa9ca3f201d52008 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 16:18:30 +0530 Subject: [PATCH 48/53] todo: 3. --- src/diffusers/loaders/lora_pipeline.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index ec1339a79bf6..f28a66f4fc02 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -441,7 +441,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -926,7 +928,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -1383,7 +1387,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -2320,7 +2326,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -2861,7 +2869,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. 
From 0eba7e79831e732a3d46c8fbd57ce04f5cbfe8fc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 16:19:29 +0530 Subject: [PATCH 49/53] update --- src/diffusers/loaders/lora_pipeline.py | 36 +++++++++++++++++--------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index f28a66f4fc02..75c85085e999 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1168,7 +1168,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -1628,7 +1629,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -1972,7 +1974,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -3006,7 +3009,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -3342,7 +3346,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -3680,7 +3685,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -4022,7 +4028,8 @@ def lora_state_dict( allowed by Git. 
subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -4360,7 +4367,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -4701,7 +4709,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -5043,7 +5052,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. @@ -5436,7 +5446,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of @@ -5774,7 +5785,8 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - return_lora_metadata: TODO + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. 
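For the `return_lora_metadata` flag documented in this patch, the object being returned is the adapter configuration that was written into the safetensors header at save time; passing `return_lora_metadata=True` to `lora_state_dict` hands it back alongside the state dict. The header can also be inspected directly without materializing any tensors. The sketch below uses the public `safetensors` API; the file name is a placeholder, and the JSON round-trip is an assumption about how structured entries are stored, since safetensors metadata values are always plain strings.

import json

from safetensors import safe_open

# Illustrative only: peek at the adapter metadata stored in a safetensors
# header. Replace the path with an actual LoRA checkpoint.
with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    raw = f.metadata() or {}

# Metadata values are strings, so structured entries (e.g. a serialized
# LoraConfig) typically need a json.loads round-trip before use.
parsed = {k: json.loads(v) if v.lstrip().startswith("{") else v for k, v in raw.items()}
print(sorted(parsed))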
From a4a15b5f18185d4f5878ced40095e5395e3299ba Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 5 Jun 2025 16:20:18 +0530 Subject: [PATCH 50/53] update --- src/diffusers/loaders/lora_pipeline.py | 52 +++++++++++++++++++------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 75c85085e999..bd3870a20dc0 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1328,7 +1328,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -1762,7 +1764,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -2215,7 +2219,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -2812,7 +2818,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -3141,7 +3149,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -3479,7 +3489,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -3821,7 +3833,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -4161,7 +4175,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -4503,7 +4519,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -4846,7 +4864,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -5239,7 +5259,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -5579,7 +5601,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. 
When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. @@ -5921,7 +5945,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata: TODO + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. From 252fd219024951792c508d70732fd8a665f04627 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 6 Jun 2025 15:57:30 +0530 Subject: [PATCH 51/53] Apply suggestions from code review Co-authored-by: Benjamin Bossan --- src/diffusers/loaders/lora_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 8b7fc34945a7..b2cbd8760fb5 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -336,7 +336,7 @@ def _load_lora_into_text_encoder( raise ValueError("PEFT backend is required for this method.") if network_alphas and metadata: - raise ValueError("Both `network_alphas` and `metadata` cannot be specified.") + raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.") peft_kwargs = {} if low_cpu_mem_usage: @@ -423,7 +423,7 @@ def _load_lora_into_text_encoder( try: lora_config = LoraConfig(**lora_config_kwargs) except TypeError as e: - raise TypeError(f"`LoraConfig` class could not be instantiated:\n{e}.") + raise TypeError("`LoraConfig` class could not be instantiated.") from e # adapter_name if adapter_name is None: @@ -933,7 +933,7 @@ def write_lora_layers( if lora_adapter_metadata and not safe_serialization: raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.") if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict): - raise ValueError("`lora_adapter_metadata` must be of type `dict`.") + raise TypeError("`lora_adapter_metadata` must be of type `dict`.") if save_function is None: if safe_serialization: From 29ff6f1e7b7426b18446320e9a7cffbd19e75c5c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 15:59:08 +0530 Subject: [PATCH 52/53] reraise. --- src/diffusers/loaders/peft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 4273f70702f0..5006ada4c030 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -302,7 +302,7 @@ def load_lora_adapter( try: lora_config = LoraConfig(**lora_config_kwargs) except TypeError as e: - raise TypeError(f"`LoraConfig` class could not be instantiated:\n{e}.") + raise TypeError("`LoraConfig` class could not be instantiated.") from e # adapter_name if adapter_name is None: From 37a225a6b35d679c590cf6050e0683c51fe58ae9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 13 Jun 2025 07:11:46 +0530 Subject: [PATCH 53/53] move argument. 
--- src/diffusers/loaders/lora_base.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 156 ++++++++++++------------- 2 files changed, 79 insertions(+), 79 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index b2cbd8760fb5..b20b56340ea4 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -327,10 +327,10 @@ def _load_lora_into_text_encoder( lora_scale=1.0, text_encoder_name="text_encoder", adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 42ab1b38ed44..27053623eeec 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -357,10 +357,10 @@ def load_lora_into_unet( network_alphas, unet, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -379,14 +379,14 @@ def load_lora_into_unet( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -420,10 +420,10 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -446,14 +446,14 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -843,10 +843,10 @@ def load_lora_into_unet( network_alphas, unet, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `unet`. 
@@ -865,14 +865,14 @@ def load_lora_into_unet( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -907,10 +907,10 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -933,14 +933,14 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -1315,10 +1315,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1333,14 +1333,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1369,10 +1369,10 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1395,14 +1395,14 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -1751,10 +1751,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1769,14 +1769,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -2224,14 +2224,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2314,10 +2314,10 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2340,14 +2340,14 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -2823,14 +2823,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2859,10 +2859,10 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2885,14 +2885,14 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -3136,10 +3136,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. 
@@ -3154,14 +3154,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3476,10 +3476,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3494,14 +3494,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3820,10 +3820,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3838,14 +3838,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4162,10 +4162,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. 
@@ -4180,14 +4180,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4506,10 +4506,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4524,14 +4524,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4851,10 +4851,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4869,14 +4869,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5246,10 +5246,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. 
@@ -5264,14 +5264,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5588,10 +5588,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5606,14 +5606,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5932,10 +5932,10 @@ def load_lora_into_transformer( state_dict, transformer, adapter_name=None, - metadata=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5950,14 +5950,14 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError(