From 07123e14d9efd5b9ddc9f31ad4dd14ea8f2e9869 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 18 Oct 2024 20:11:55 +0530 Subject: [PATCH 01/12] add first draft. --- src/diffusers/loaders/lora_base.py | 237 ++++++++++++------------- src/diffusers/loaders/lora_pipeline.py | 120 +++---------- src/diffusers/loaders/peft.py | 173 ++++++++++++++++++ tests/lora/utils.py | 4 +- 4 files changed, 318 insertions(+), 216 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index e124b6eeacf3..b1809967dca0 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -181,6 +181,123 @@ def _remove_text_encoder_monkey_patch(text_encoder): text_encoder._hf_peft_config_loaded = None +def _fetch_state_dict( + pretrained_model_name_or_path_or_dict, + weight_name, + use_safetensors, + local_files_only, + cache_dir, + force_download, + proxies, + token, + revision, + subfolder, + user_agent, + allow_pickle, +): + from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + # Here we're relaxing the loading check to enable more Inference API + # friendliness where sometimes, it's not at all possible to automatically + # determine `weight_name`. + if weight_name is None: + weight_name = _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, + file_extension=".safetensors", + local_files_only=local_files_only, + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except (IOError, safetensors.SafetensorError) as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + model_file = None + pass + + if model_file is None: + if weight_name is None: + weight_name = _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + return state_dict + + +def _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False +): + from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE + + if local_files_only or HF_HUB_OFFLINE: + raise ValueError("When using the offline mode, you must specify a `weight_name`.") + + targeted_files = [] + + if os.path.isfile(pretrained_model_name_or_path_or_dict): + return + elif os.path.isdir(pretrained_model_name_or_path_or_dict): + targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)] + else: + files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings 
+ targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] + if len(targeted_files) == 0: + return + + # "scheduler" does not correspond to a LoRA checkpoint. + # "optimizer" does not correspond to a LoRA checkpoint + # only top-level checkpoints are considered and not the other ones, hence "checkpoint". + unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} + targeted_files = list( + filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) + ) + + if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) + elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) + + if len(targeted_files) > 1: + raise ValueError( + f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." + ) + weight_name = targeted_files[0] + return weight_name + + class LoraBaseMixin: """Utility class for handling LoRAs.""" @@ -233,126 +350,6 @@ def _optionally_disable_offloading(cls, _pipeline): return (is_model_cpu_offload, is_sequential_cpu_offload) - @classmethod - def _fetch_state_dict( - cls, - pretrained_model_name_or_path_or_dict, - weight_name, - use_safetensors, - local_files_only, - cache_dir, - force_download, - proxies, - token, - revision, - subfolder, - user_agent, - allow_pickle, - ): - from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE - - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - # Here we're relaxing the loading check to enable more Inference API - # friendliness where sometimes, it's not at all possible to automatically - # determine `weight_name`. 
- if weight_name is None: - weight_name = cls._best_guess_weight_name( - pretrained_model_name_or_path_or_dict, - file_extension=".safetensors", - local_files_only=local_files_only, - ) - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except (IOError, safetensors.SafetensorError) as e: - if not allow_pickle: - raise e - # try loading non-safetensors weights - model_file = None - pass - - if model_file is None: - if weight_name is None: - weight_name = cls._best_guess_weight_name( - pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only - ) - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = load_state_dict(model_file) - else: - state_dict = pretrained_model_name_or_path_or_dict - - return state_dict - - @classmethod - def _best_guess_weight_name( - cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False - ): - from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE - - if local_files_only or HF_HUB_OFFLINE: - raise ValueError("When using the offline mode, you must specify a `weight_name`.") - - targeted_files = [] - - if os.path.isfile(pretrained_model_name_or_path_or_dict): - return - elif os.path.isdir(pretrained_model_name_or_path_or_dict): - targeted_files = [ - f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension) - ] - else: - files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings - targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] - if len(targeted_files) == 0: - return - - # "scheduler" does not correspond to a LoRA checkpoint. - # "optimizer" does not correspond to a LoRA checkpoint - # only top-level checkpoints are considered and not the other ones, hence "checkpoint". - unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} - targeted_files = list( - filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) - ) - - if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): - targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) - elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): - targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) - - if len(targeted_files) > 1: - raise ValueError( - f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." - ) - weight_name = targeted_files[0] - return weight_name - def unload_lora_weights(self): """ Unloads the LoRA parameters. 
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 5e01ec567f9a..61c5503e5c3a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -33,7 +33,7 @@ logging, scale_lora_layers, ) -from .lora_base import LoraBaseMixin +from .lora_base import LoraBaseMixin, _fetch_state_dict from .lora_conversion_utils import ( _convert_kohya_flux_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, @@ -222,7 +222,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -744,7 +744,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1182,7 +1182,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1250,13 +1250,17 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k} + if len(transformer_state_dict) > 0: + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) + if not hasattr(self, "transformer") + else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: @@ -1308,87 +1312,15 @@ def load_lora_into_transformer( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
- ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - peft_kwargs = {} - if is_peft_version(">=", "0.13.1"): - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) - - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. " - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." - ) - - if warn_msg: - logger.warning(warn_msg) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + # Load the layers corresponding to transformer. 
+ logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -1742,7 +1674,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -2619,7 +2551,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index d1c6721512fa..ac6828571553 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -16,18 +16,32 @@ from functools import partial from typing import Dict, List, Optional, Union +import torch.nn as nn + from ..utils import ( MIN_PEFT_VERSION, USE_PEFT_BACKEND, check_peft_version, + convert_unet_state_dict_to_peft, delete_adapter_layers, + get_adapter_name, + get_peft_kwargs, + is_accelerate_available, is_peft_available, + is_peft_version, + logging, set_adapter_layers, set_weights_and_activate_adapters, ) +from .lora_base import _fetch_state_dict from .unet_loader_utils import _maybe_expand_lora_scales +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + +logger = logging.get_logger(__name__) + _SET_ADAPTER_SCALE_FN_MAPPING = { "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, @@ -53,6 +67,165 @@ class PeftAdapterMixin: _hf_peft_config_loaded = False + @classmethod + # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading + def _optionally_disable_offloading(cls, _pipeline): + """ + Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + """ + is_model_cpu_offload = False + is_sequential_cpu_offload = False + + if _pipeline is not None and _pipeline.hf_device_map is None: + for _, component in _pipeline.components.items(): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if not is_model_cpu_offload: + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) + if not is_sequential_cpu_offload: + is_sequential_cpu_offload = ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
+ ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + + return (is_model_cpu_offload, is_sequential_cpu_offload) + + def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs): + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + adapter_name = kwargs.pop("adapter_name", None) + _pipeline = kwargs.pop("_pipeline", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) + allow_pickle = False + + if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + keys = list(state_dict.keys()) + transformer_keys = [k for k in keys if k.startswith(prefix)] + if len(transformer_keys) > 0: + state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys} + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(self, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(self) + + # =", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + + warn_msg = "" + if incompatible_keys is not None: + # Check only for unexpected keys. 
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] + if lora_unexpected_keys: + warn_msg = ( + f"Loading adapter weights from state_dict led to unexpected keys found in the model:" + f" {', '.join(lora_unexpected_keys)}. " + ) + + # Filter missing keys specific to the current adapter. + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model:" + f" {', '.join(lora_missing_keys)}." + ) + + if warn_msg: + logger.warning(warn_msg) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + def set_adapters( self, adapter_names: Union[List[str], str], diff --git a/tests/lora/utils.py b/tests/lora/utils.py index e7fc840fcaa5..b711c8c9791e 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1787,7 +1787,7 @@ def test_missing_keys_warning(self): logger = ( logging.get_logger("diffusers.loaders.unet") if self.unet_kwargs is not None - else logging.get_logger("diffusers.loaders.lora_pipeline") + else logging.get_logger("diffusers.loaders.peft") ) logger.setLevel(30) with CaptureLogger(logger) as cap_logger: @@ -1826,7 +1826,7 @@ def test_unexpected_keys_warning(self): logger = ( logging.get_logger("diffusers.loaders.unet") if self.unet_kwargs is not None - else logging.get_logger("diffusers.loaders.lora_pipeline") + else logging.get_logger("diffusers.loaders.peft") ) logger.setLevel(30) with CaptureLogger(logger) as cap_logger: From d099b8464864fe80ba5bcadbf2dcba6a45f31fd3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 18 Oct 2024 20:19:09 +0530 Subject: [PATCH 02/12] fix --- src/diffusers/loaders/lora_pipeline.py | 146 +++++++++---------------- 1 file changed, 51 insertions(+), 95 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 61c5503e5c3a..31a38f2b7bb3 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -282,7 +282,9 @@ def load_lora_into_unet( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -341,7 +343,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. 
""" if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -601,7 +605,9 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -805,7 +811,9 @@ def load_lora_into_unet( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -865,7 +873,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1226,7 +1236,9 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -1305,7 +1317,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1356,7 +1370,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1751,7 +1767,9 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1818,7 +1836,9 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -1946,7 +1966,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -2309,7 +2331,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -2590,7 +2614,9 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -2640,94 +2666,24 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - peft_kwargs = {} - if is_peft_version(">=", "0.13.1"): - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) - - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. " - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." - ) - - if warn_msg: - logger.warning(warn_msg) - - # Offload back. 
- if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder From 4d307cc58cf81d6c024664c9bc07cfe370c445ec Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 18 Oct 2024 20:22:13 +0530 Subject: [PATCH 03/12] updates. --- src/diffusers/loaders/lora_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 31a38f2b7bb3..27ff4a7791d0 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2649,7 +2649,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False ): @@ -2661,7 +2661,7 @@ def load_lora_into_transformer( A standard state dict containing the lora layer parameters. The keys can either be indexed directly into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. - transformer (`SD3Transformer2DModel`): + transformer (`CogVideoXTransformer3DModel`): The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use From 984b8c9490a82c1157a345dcccc0d17caae2560a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 18 Oct 2024 20:41:07 +0530 Subject: [PATCH 04/12] updates. --- src/diffusers/loaders/lora_pipeline.py | 96 +++----------------------- src/diffusers/loaders/peft.py | 8 ++- 2 files changed, 16 insertions(+), 88 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 27ff4a7791d0..9d31ecdff2ef 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -21,7 +21,6 @@ USE_PEFT_BACKEND, convert_state_dict_to_diffusers, convert_state_dict_to_peft, - convert_unet_state_dict_to_peft, deprecate, get_adapter_name, get_peft_kwargs, @@ -1845,92 +1844,15 @@ def load_lora_into_transformer( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
) - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - if network_alphas is not None and len(network_alphas) >= 1: - prefix = cls.transformer_name - alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - peft_kwargs = {} - if is_peft_version(">=", "0.13.1"): - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) - - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. " - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." - ) - - if warn_msg: - logger.warning(warn_msg) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + # Load the layers corresponding to transformer. 
+ logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index ac6828571553..779bd5316393 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -116,6 +116,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) adapter_name = kwargs.pop("adapter_name", None) + network_alphas = kwargs.pop("network_alphas", None) _pipeline = kwargs.pop("_pipeline", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) allow_pickle = False @@ -166,7 +167,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans if "lora_B" in key: rank[key] = val.shape[1] - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if network_alphas is not None and len(network_alphas) >= 1: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] + network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): raise ValueError( @@ -187,6 +192,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) + peft_kwargs = {} if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage From 2e70a9345cd59cf3ae89839d51822687cbd290c5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 18 Oct 2024 20:45:20 +0530 Subject: [PATCH 05/12] updates --- src/diffusers/loaders/lora_pipeline.py | 107 ++++++------------------- 1 file changed, 23 insertions(+), 84 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9d31ecdff2ef..0750a4a33fb5 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1830,7 +1830,7 @@ def load_lora_into_transformer( The value of the network alpha used for stable learning and preventing underflow. This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - transformer (`SD3Transformer2DModel`): + transformer (`FluxTransformer2DModel`): The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use @@ -2118,7 +2118,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): text_encoder_name = TEXT_ENCODER_NAME @classmethod - def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None): + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel + def load_lora_into_transformer( + cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -2131,93 +2134,29 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada The value of the network alpha used for stable learning and preventing underflow. This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - unet (`UNet2DConditionModel`): - The UNet model to load the LoRA layers into. + transformer (`UVit2DModel`): + The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)] - network_alphas = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys - } - - if len(state_dict.keys()) > 0: - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
- ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) - - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. " - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." - ) - - if warn_msg: - logger.warning(warn_msg) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder From c0f45856073ccf5293488b3a8f68cbef6ec4272e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 18 Oct 2024 21:08:44 +0530 Subject: [PATCH 06/12] updates --- src/diffusers/loaders/lora_pipeline.py | 78 +++++++++++++++----------- src/diffusers/loaders/peft.py | 3 + 2 files changed, 48 insertions(+), 33 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 0750a4a33fb5..4c4b1af4b4ed 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1326,14 +1326,17 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + keys = list(state_dict.keys()) + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) + if not only_text_encoder: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -1845,14 +1848,17 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. 
- logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + keys = list(state_dict.keys()) + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) + if not only_text_encoder: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2118,7 +2124,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): text_encoder_name = TEXT_ENCODER_NAME @classmethod - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel def load_lora_into_transformer( cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False ): @@ -2149,14 +2155,17 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + keys = list(state_dict.keys()) + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) + if not only_text_encoder: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2537,14 +2546,17 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. 
- logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + keys = list(state_dict.keys()) + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) + if not only_text_encoder: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 779bd5316393..d3ea9022104a 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -104,6 +104,9 @@ def _optionally_disable_offloading(cls, _pipeline): return (is_model_cpu_offload, is_sequential_cpu_offload) def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs): + """ + TODO + """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict cache_dir = kwargs.pop("cache_dir", None) From 0f6ce88262c53be198fa6ea5983b9f078530d24a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 18 Oct 2024 21:30:21 +0530 Subject: [PATCH 07/12] updates. --- src/diffusers/loaders/lora_pipeline.py | 66 +++++++++++++------------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 4c4b1af4b4ed..864a447e045a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1327,8 +1327,8 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. keys = list(state_dict.keys()) - only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) - if not only_text_encoder: + transformer_present = any(key.startswith(cls.transformer_name) for key in keys) + if transformer_present: logger.info(f"Loading {cls.transformer_name}.") transformer.load_lora_adapter( state_dict, @@ -1795,14 +1795,18 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_transformer( - state_dict, - network_alphas=network_alphas, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k} + if len(transformer_state_dict) > 0: + self.load_lora_into_transformer( + state_dict, + network_alphas=network_alphas, + transformer=getattr(self, self.transformer_name) + if not hasattr(self, "transformer") + else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: @@ -1849,8 +1853,8 @@ def load_lora_into_transformer( # Load the layers corresponding to transformer. 
keys = list(state_dict.keys()) - only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) - if not only_text_encoder: + transformer_present = any(key.startswith(cls.transformer_name) for key in keys) + if transformer_present: logger.info(f"Loading {cls.transformer_name}.") transformer.load_lora_adapter( state_dict, @@ -2155,17 +2159,14 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. - keys = list(state_dict.keys()) - only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) - if not only_text_encoder: - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2546,17 +2547,14 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. - keys = list(state_dict.keys()) - only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) - if not only_text_encoder: - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder From 35414953f2759cb4ed371fa7c03ef0678af37005 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 18 Oct 2024 21:34:24 +0530 Subject: [PATCH 08/12] fix-copies --- src/diffusers/loaders/lora_pipeline.py | 38 +++++++++++++------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 864a447e045a..31aca552d67d 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1326,17 +1326,14 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. - keys = list(state_dict.keys()) - transformer_present = any(key.startswith(cls.transformer_name) for key in keys) - if transformer_present: - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2159,14 +2156,17 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. 
- logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + keys = list(state_dict.keys()) + transformer_present = any(key.startswith(cls.transformer_name) for key in keys) + if transformer_present: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder From d3afa2687e8684952bdaf0051775dada2fe18966 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 25 Oct 2024 00:00:07 +0900 Subject: [PATCH 09/12] lora constants. --- src/diffusers/loaders/lora_base.py | 9 +++------ src/diffusers/loaders/lora_pipeline.py | 5 +---- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index b1809967dca0..783bf5791033 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -51,6 +51,9 @@ logger = logging.get_logger(__name__) +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): """ @@ -195,8 +198,6 @@ def _fetch_state_dict( user_agent, allow_pickle, ): - from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE - model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights @@ -260,8 +261,6 @@ def _fetch_state_dict( def _best_guess_weight_name( pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False ): - from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE - if local_files_only or HF_HUB_OFFLINE: raise ValueError("When using the offline mode, you must specify a `weight_name`.") @@ -722,8 +721,6 @@ def write_lora_layers( save_function: Callable, safe_serialization: bool, ): - from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE - if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 31aca552d67d..154aa2d8f9bb 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -32,7 +32,7 @@ logging, scale_lora_layers, ) -from .lora_base import LoraBaseMixin, _fetch_state_dict +from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa from .lora_conversion_utils import ( _convert_kohya_flux_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, @@ -61,9 +61,6 @@ UNET_NAME = "unet" TRANSFORMER_NAME = "transformer" -LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" -LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" - class StableDiffusionLoraLoaderMixin(LoraBaseMixin): r""" From c28d6f3fb39c4666057a8b273327fb14150045db Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 25 Oct 2024 00:22:19 +0900 Subject: [PATCH 10/12] add tests --- src/diffusers/loaders/lora_base.py | 12 ++++++++ tests/lora/test_deprecated_utilities.py | 39 +++++++++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 
tests/lora/test_deprecated_utilities.py diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 783bf5791033..131d1f299b62 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -349,6 +349,18 @@ def _optionally_disable_offloading(cls, _pipeline): return (is_model_cpu_offload, is_sequential_cpu_offload) + @classmethod + def _fetch_state_dict(cls, *args, **kwargs): + deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`." + deprecate("_fetch_state_dict", "0.35.0", deprecation_message) + _fetch_state_dict(*args, **kwargs) + + @classmethod + def _best_guess_weight_name(cls, *args, **kwargs): + deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`." + deprecate("_best_guess_weight_name", "0.35.0", deprecation_message) + _best_guess_weight_name(*args, **kwargs) + def unload_lora_weights(self): """ Unloads the LoRA parameters. diff --git a/tests/lora/test_deprecated_utilities.py b/tests/lora/test_deprecated_utilities.py new file mode 100644 index 000000000000..4275ef8089a3 --- /dev/null +++ b/tests/lora/test_deprecated_utilities.py @@ -0,0 +1,39 @@ +import os +import tempfile +import unittest + +import torch + +from diffusers.loaders.lora_base import LoraBaseMixin + + +class UtilityMethodDeprecationTests(unittest.TestCase): + def test_fetch_state_dict_cls_method_raises_warning(self): + state_dict = torch.nn.Linear(3, 3).state_dict() + with self.assertWarns(FutureWarning) as warning: + _ = LoraBaseMixin._fetch_state_dict( + state_dict, + weight_name=None, + use_safetensors=False, + local_files_only=True, + cache_dir=None, + force_download=False, + proxies=None, + token=None, + revision=None, + subfolder=None, + user_agent=None, + allow_pickle=None, + ) + warning_message = str(warning.warnings[0].message) + assert "Using the `_fetch_state_dict()` method from" in warning_message + + def test_best_guess_weight_name_cls_method_raises_warning(self): + with tempfile.TemporaryDirectory() as tmpdir: + state_dict = torch.nn.Linear(3, 3).state_dict() + torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin")) + + with self.assertWarns(FutureWarning) as warning: + _ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir) + warning_message = str(warning.warnings[0].message) + assert "Using the `_best_guess_weight_name()` method from" in warning_message From e187b706a546a54c95abffb951f0ec0292400627 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 1 Nov 2024 10:45:24 +0530 Subject: [PATCH 11/12] Apply suggestions from code review Co-authored-by: Benjamin Bossan --- src/diffusers/loaders/lora_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 131d1f299b62..286d0a12bc71 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -353,13 +353,13 @@ def _optionally_disable_offloading(cls, _pipeline): def _fetch_state_dict(cls, *args, **kwargs): deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`." 
deprecate("_fetch_state_dict", "0.35.0", deprecation_message) - _fetch_state_dict(*args, **kwargs) + return _fetch_state_dict(*args, **kwargs) @classmethod def _best_guess_weight_name(cls, *args, **kwargs): deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`." deprecate("_best_guess_weight_name", "0.35.0", deprecation_message) - _best_guess_weight_name(*args, **kwargs) + return _best_guess_weight_name(*args, **kwargs) def unload_lora_weights(self): """ From 6ca6c659d62bcd6831527c191d71093140f56a44 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 2 Nov 2024 09:25:17 +0530 Subject: [PATCH 12/12] docstrings. --- src/diffusers/loaders/peft.py | 45 +++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index d3ea9022104a..cf361e88a670 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -104,8 +104,49 @@ def _optionally_disable_offloading(cls, _pipeline): return (is_model_cpu_offload, is_sequential_cpu_offload) def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs): - """ - TODO + r""" + Loads a LoRA adapter into the underlying model. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + prefix (`str`, *optional*): Prefix to filter the state dict. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
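
Taken together, patches 09 through 12 leave two supported call paths for the LoRA-loading helpers: the module-level functions now living in diffusers.loaders.lora_base (alongside the relocated LORA_WEIGHT_NAME constants), and the old classmethod entry points on LoraBaseMixin, which remain as deprecation shims that warn and forward their return value until 0.35.0. Below is a minimal sketch of both paths, assuming the module layout shown in the diffs above; the toy state dict, the temporary directory, and the commented-out repository id are illustrative only and mirror the calls exercised in the new tests/lora/test_deprecated_utilities.py.

    import os
    import tempfile

    import torch

    from diffusers.loaders.lora_base import (
        LORA_WEIGHT_NAME,
        LoraBaseMixin,
        _best_guess_weight_name,
        _fetch_state_dict,
    )

    toy_state_dict = torch.nn.Linear(3, 3).state_dict()

    # Preferred going forward: call the module-level helper directly. The argument
    # list matches the one used in the deprecation tests; an in-memory state dict
    # is accepted as-is, so no Hub or filesystem lookup is attempted here.
    state_dict = _fetch_state_dict(
        toy_state_dict,
        weight_name=None,
        use_safetensors=False,
        local_files_only=True,
        cache_dir=None,
        force_download=False,
        proxies=None,
        token=None,
        revision=None,
        subfolder=None,
        user_agent=None,
        allow_pickle=None,
    )

    # Still supported until 0.35.0: the classmethod shims on LoraBaseMixin emit a
    # FutureWarning pointing at the module-level imports above and, after the
    # code-review fix in PATCH 11, forward the return value.
    with tempfile.TemporaryDirectory() as tmpdir:
        torch.save(toy_state_dict, os.path.join(tmpdir, LORA_WEIGHT_NAME))
        _ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)

    # The model-level entry point documented in PATCH 12 accepts the same kinds of
    # sources (Hub id, local directory, or state dict); the repository id below is
    # purely hypothetical:
    # transformer.load_lora_adapter("some-org/some-lora", prefix="transformer")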