From a9f47624311c233a82b3494f861df3111e2aa30b Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 17 Feb 2025 10:01:53 +0530
Subject: [PATCH 1/5] restrict certain keys to be checked for peft config update.

---
 src/diffusers/loaders/peft.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 0d26738eec62..569d66966379 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -53,6 +53,7 @@
     "LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
     "SanaTransformer2DModel": lambda model_cls, weights: weights,
 }
+_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]
 
 
 def _maybe_adjust_config(config):
@@ -67,6 +68,8 @@ def _maybe_adjust_config(config):
     original_r = config["r"]
 
     for key in list(rank_pattern.keys()):
+        if any(prefix in key for prefix in _NO_CONFIG_UPDATE_KEYS):
+            continue
         key_rank = rank_pattern[key]
 
         # try to detect ambiguity
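For context on the guard added above: any key in `rank_pattern` that contains one of the attention projection names in `_NO_CONFIG_UPDATE_KEYS` is now skipped by `_maybe_adjust_config`, while the remaining keys still go through the ambiguity check. A minimal sketch of that behaviour, using made-up keys and ranks rather than values from a real checkpoint:

# Illustration only: mirrors the guard from the patch with made-up rank data.
_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]

rank_pattern = {
    "single_transformer_blocks.0.attn.to_k": 16,  # contains "to_k" -> left untouched
    "proj_out": 8,                                # still checked for ambiguity
}

for key in list(rank_pattern.keys()):
    if any(prefix in key for prefix in _NO_CONFIG_UPDATE_KEYS):
        continue  # do not let this key influence the config update
    print(f"{key} goes through the exact/substring ambiguity check")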
From 8c988f4eaa5057e2471c61b75580773572781a34 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 18 Feb 2025 15:28:06 +0530
Subject: [PATCH 2/5] updates

---
 src/diffusers/loaders/peft.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 569d66966379..367a5ff49709 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -54,6 +54,7 @@
     "SanaTransformer2DModel": lambda model_cls, weights: weights,
 }
 _NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]
+_FULL_NAME_PREFIX_FOR_PEFT = "FULL-NAME"
 
 
 def _maybe_adjust_config(config):
@@ -188,6 +189,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
+        from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
 
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
@@ -253,14 +255,14 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
                 # Bias layers in LoRA only have a single dimension
                 if "lora_B" in key and val.ndim > 1:
-                    rank[key] = val.shape[1]
+                    rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
 
             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+            # lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
 
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:

From bb0d4a1879f1707480006fa2041fa9e783adc8ce Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 22 Feb 2025 22:42:18 +0530
Subject: [PATCH 3/5] finish./

---
 src/diffusers/loaders/peft.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 367a5ff49709..cbf658400135 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -54,7 +54,6 @@
     "SanaTransformer2DModel": lambda model_cls, weights: weights,
 }
 _NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]
-_FULL_NAME_PREFIX_FOR_PEFT = "FULL-NAME"
 
 
 def _maybe_adjust_config(config):
@@ -189,7 +188,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
-        from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
+
+        try:
+            from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
+        except ImportError:
+            FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
 
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
@@ -255,14 +258,22 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
                 # Bias layers in LoRA only have a single dimension
                 if "lora_B" in key and val.ndim > 1:
-                    rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
+                    # Support to handle cases where layer patterns are treated as full layer names
+                    # was added later in PEFT. So, we handle it accordingly.
+                    # TODO: when we fix the minimal PEFT version for Diffusers,
+                    # we should remove `_maybe_adjust_config()`.
+                    if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                        rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
+                    else:
+                        rank[key] = val.shape[1]
 
             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            # lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+            if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
 
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
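For context on the two patches above: the LoRA rank is read off the second dimension of each `lora_B` weight, and, when the installed PEFT exposes `FULLY_QUALIFIED_PATTERN_KEY_PREFIX`, every key is prefixed with it so PEFT treats the entry as a fully qualified module name rather than a pattern; on older PEFT the plain key is kept and `_maybe_adjust_config()` stays as the fallback. A minimal sketch with made-up shapes and the prefix stubbed out as None to emulate an older PEFT:

import torch

# Made-up checkpoint entries; real keys come from the loaded LoRA state dict.
state_dict = {
    "single_transformer_blocks.0.attn.to_k.lora_A.weight": torch.zeros(16, 3072),
    "single_transformer_blocks.0.attn.to_k.lora_B.weight": torch.zeros(3072, 16),
    "proj_out.lora_B.weight": torch.zeros(128, 8),
}

# Stand-in for peft.utils.constants.FULLY_QUALIFIED_PATTERN_KEY_PREFIX,
# which only exists in newer PEFT releases.
FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None

rank = {}
for key, val in state_dict.items():
    # lora_B has shape (out_features, r), so the rank is the second dimension.
    if "lora_B" in key and val.ndim > 1:
        if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
            rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
        else:
            rank[key] = val.shape[1]

# {'single_transformer_blocks.0.attn.to_k.lora_B.weight': 16, 'proj_out.lora_B.weight': 8}
print(rank)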
From cbc4432a98063652fd775606977466b31a0dc33c Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 22 Feb 2025 23:08:02 +0530
Subject: [PATCH 4/5] finish 2.

---
 src/diffusers/loaders/peft.py | 25 +++++++++++--------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 85f93e543aec..2eaaa20da1f1 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -54,7 +54,6 @@
     "SanaTransformer2DModel": lambda model_cls, weights: weights,
     "Lumina2Transformer2DModel": lambda model_cls, weights: weights,
 }
-_NO_CONFIG_UPDATE_KEYS = ["to_k", "to_q", "to_v"]
 
 
 def _maybe_adjust_config(config):
@@ -64,40 +63,38 @@ def _maybe_adjust_config(config):
     method removes the ambiguity by following what is described here:
     https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
     """
+    # Track keys that have been explicitly removed to prevent re-adding them.
+    deleted_keys = set()
+
     rank_pattern = config["rank_pattern"].copy()
     target_modules = config["target_modules"]
     original_r = config["r"]
 
     for key in list(rank_pattern.keys()):
-        if any(prefix in key for prefix in _NO_CONFIG_UPDATE_KEYS):
-            continue
         key_rank = rank_pattern[key]
 
         # try to detect ambiguity
-        # `target_modules` can also be a str, in which case this loop would loop
-        # over the chars of the str. The technically correct way to match LoRA keys
-        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
-        # But this cuts it for now.
         exact_matches = [mod for mod in target_modules if mod == key]
         substring_matches = [mod for mod in target_modules if key in mod and mod != key]
         ambiguous_key = key
 
         if exact_matches and substring_matches:
-            # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
+            # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
             config["r"] = key_rank
-            # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
+            # remove the ambiguous key from `rank_pattern` and record it as deleted
             del config["rank_pattern"][key]
+            deleted_keys.add(key)
+            # For substring matches, add them with the original rank only if they haven't been assigned already
             for mod in substring_matches:
-                # avoid overwriting if the module already has a specific rank
-                if mod not in config["rank_pattern"]:
+                if mod not in config["rank_pattern"] and mod not in deleted_keys:
                     config["rank_pattern"][mod] = original_r
 
-            # update the rest of the keys with the `original_r`
+            # Update the rest of the target modules with the original rank if not already set and not deleted
            for mod in target_modules:
-                if mod != ambiguous_key and mod not in config["rank_pattern"]:
+                if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
                     config["rank_pattern"][mod] = original_r
 
-    # handle alphas to deal with cases like
+    # Handle alphas to deal with cases like:
     # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
     has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
     if has_different_ranks:
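A worked example of the resolution the rewritten helper performs, using made-up module names and ranks (the alpha handling that follows `has_different_ranks` is not shown): `proj_out` is both an exact target module and a prefix of `blocks.0.proj_out`, so its rank is promoted to `r`, the key is dropped from `rank_pattern`, and `deleted_keys` prevents it from being re-added with the original rank.

# Illustration only: a made-up config showing the intended before/after state.
lora_config_kwargs = {
    "r": 4,
    "rank_pattern": {"proj_out": 8},
    "target_modules": ["proj_out", "blocks.0.proj_out", "to_out.0"],
}

# After the adjustment sketched in the patch:
#   - "proj_out" is ambiguous (exact match and substring of "blocks.0.proj_out"),
#     so r becomes 8, "proj_out" is deleted from rank_pattern and remembered
#     in deleted_keys so it is not re-added below.
#   - the remaining target modules keep the original rank of 4.
expected_after_adjustment = {
    "r": 8,
    "rank_pattern": {"blocks.0.proj_out": 4, "to_out.0": 4},
    "target_modules": ["proj_out", "blocks.0.proj_out", "to_out.0"],
}
print(expected_after_adjustment)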
From a81351b19f2995ecf373141bc0ab3e5c925b0741 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 24 Feb 2025 16:40:39 +0530
Subject: [PATCH 5/5] updates

---
 src/diffusers/loaders/peft.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 2eaaa20da1f1..da038b9fdca5 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -74,6 +74,10 @@ def _maybe_adjust_config(config):
         key_rank = rank_pattern[key]
 
         # try to detect ambiguity
+        # `target_modules` can also be a str, in which case this loop would loop
+        # over the chars of the str. The technically correct way to match LoRA keys
+        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
+        # But this cuts it for now.
         exact_matches = [mod for mod in target_modules if mod == key]
         substring_matches = [mod for mod in target_modules if key in mod and mod != key]
         ambiguous_key = key
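The restored comment points at a real pitfall worth spelling out: if `target_modules` is a plain string rather than a list, iterating over it yields individual characters, so the exact-match and substring checks above would quietly compare LoRA keys against single characters. A throwaway illustration:

target_modules = "to_k"  # a str, not a list
key = "to_k"

# Iterating over a str loops over its characters ...
print([mod for mod in target_modules])  # ['t', 'o', '_', 'k']

# ... so the exact-match check never fires even though the key clearly "matches".
exact_matches = [mod for mod in target_modules if mod == key]
print(exact_matches)  # []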