Commit 5c4976b: "fixes"
1 parent: 367153d

File tree: 4 files changed (+74, -35 lines)


src/diffusers/loaders/lora_base.py

Lines changed: 4 additions & 18 deletions
@@ -344,6 +344,10 @@ def _load_lora_into_text_encoder(
 
     # Safe prefix to check with.
     if any(text_encoder_name in key for key in keys):
+        # adapter_name
+        if adapter_name is None:
+            adapter_name = get_adapter_name(text_encoder)
+
         # Load the layers corresponding to text encoder and make necessary adjustments.
         text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
         text_encoder_lora_state_dict = {
@@ -358,20 +362,6 @@ def _load_lora_into_text_encoder(
         # convert state dict
         text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
 
-        if any("position_embedding" in k for k in text_encoder_lora_state_dict):
-            # TODO: this copying is a big shot in the dark.
-            # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=RM_Artistify_v1.0M.safetensors
-            # only has LoRA keys for the position embedding but not the LoRA embedding keys.
-            text_encoder_lora_state_dict[
-                "text_model.embeddings.position_embedding.lora_embedding_A.weight"
-            ] = text_encoder_lora_state_dict["text_model.embeddings.position_embedding.lora_A.weight"].clone()
-            text_encoder_lora_state_dict[
-                "text_model.embeddings.position_embedding.lora_embedding_B.weight"
-            ] = text_encoder_lora_state_dict["text_model.embeddings.position_embedding.lora_B.weight"].clone()
-            rank["text_model.embeddings.position_embedding.lora_B.weight"] = text_encoder_lora_state_dict[
-                "text_model.embeddings.position_embedding.lora_B.weight"
-            ].shape[1]
-
         for name, _ in text_encoder_attn_modules(text_encoder):
             for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
                 rank_key = f"{name}.{module}.lora_B.weight"
@@ -414,10 +404,6 @@ def _load_lora_into_text_encoder(
 
         lora_config = LoraConfig(**lora_config_kwargs)
 
-        # adapter_name
-        if adapter_name is None:
-            adapter_name = get_adapter_name(text_encoder)
-
         is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
 
         # inject LoRA layers and load the state dict
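
Side note on the first hunk: the `adapter_name` fallback now runs before any of the state-dict filtering and rank detection, so the rest of the function can rely on a concrete name. The fallback itself comes from `get_adapter_name` in diffusers; a rough, hypothetical sketch of that kind of defaulting (not the library's implementation) looks like:

    def default_adapter_name(model):
        # Hypothetical sketch: if PEFT adapters are already attached, the model
        # typically exposes them via `peft_config`; pick the next free
        # "default_N" slot, otherwise start at "default_0".
        existing = getattr(model, "peft_config", {}) or {}
        return f"default_{len(existing)}"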

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 53 additions & 15 deletions
@@ -17,7 +17,7 @@
 
 import torch
 
-from ..utils import is_peft_version, logging
+from ..utils import is_peft_version, logging, state_dict_all_zero
 
 
 logger = logging.get_logger(__name__)
@@ -755,29 +755,67 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
     state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
     state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
     has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
+
+    if any("position_embedding" in k for k in state_dict):
+        zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
+        if zero_status_pe:
+            logger.info(
+                "The `position_embedding` LoRA params are all zeros which make them ineffective. "
+                "So, we will purge them out of the current state dict to make loading possible."
+            )
+            current_pe_lora_keys = [k for k in state_dict if "position_embedding" in k]
+            for k in current_pe_lora_keys:
+                state_dict.pop(k)
+        else:
+            raise NotImplementedError(
+                "The state_dict has position_embedding LoRA params and we currently do not support them. "
+                "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
+            )
+
     if has_t5xxl:
-        logger.info(
-            "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
-            "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
-        )
+        zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
+        if zero_status_t5:
+            logger.info(
+                "The `t5xxl` LoRA params are all zeros which make them ineffective. "
+                "So, we will purge them out of the current state dict to make loading possible."
+            )
+        else:
+            logger.info(
+                "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out. "
+                "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
+            )
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
 
     any_diffb_keys = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
     if any_diffb_keys:
-        logger.info(
-            "`diff_b` keys found in the state dict which are currently unsupported. "
-            "So, we will filter out those keys. Open an issue if this is a problem - "
-            "https://github.com/huggingface/diffusers/issues/new."
-        )
+        zero_status_diff_b = state_dict_all_zero(state_dict, "diff_b")
+        if zero_status_diff_b:
+            logger.info(
+                "The `diff_b` LoRA params are all zeros which make them ineffective. "
+                "So, we will purge them out of the current state dict to make loading possible."
+            )
+        else:
+            logger.info(
+                "`diff_b` keys found in the state dict which are currently unsupported. "
+                "So, we will filter out those keys. Open an issue if this is a problem - "
+                "https://github.com/huggingface/diffusers/issues/new."
+            )
        state_dict = {k: v for k, v in state_dict.items() if "diff_b" not in k}
 
     any_norm_diff_keys = any("norm" in k and "diff" in k for k in state_dict)
     if any_norm_diff_keys:
-        logger.info(
-            "Normalization diff keys found in the state dict which are currently unsupported. "
-            "So, we will filter out those keys. Open an issue if this is a problem - "
-            "https://github.com/huggingface/diffusers/issues/new."
-        )
+        zero_status_diff = state_dict_all_zero(state_dict, "diff")
+        if zero_status_diff:
+            logger.info(
+                "The `diff` LoRA params are all zeros which make them ineffective. "
+                "So, we will purge them out of the current state dict to make loading possible."
+            )
+        else:
+            logger.info(
+                "Normalization diff keys found in the state dict which are currently unsupported. "
+                "So, we will filter out those keys. Open an issue if this is a problem - "
+                "https://github.com/huggingface/diffusers/issues/new."
+            )
        state_dict = {k: v for k, v in state_dict.items() if "norm" not in k and "diff" not in k}
 
     limit_substrings = ["lora_down", "lora_up"]
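
The three new branches share one pattern: probe a slice of the state dict with `state_dict_all_zero`, silently drop the matching keys when they are inert (all-zero LoRA tensors contribute nothing to the weight delta), and only log or raise when they actually carry values. A minimal standalone sketch of that pattern (helper name and filter value made up for illustration):

    import torch

    def purge_if_all_zero(state_dict, substr):
        # Keys matching the given substring, e.g. substr = "position_embedding".
        matching = [k for k in state_dict if substr in k]
        if matching and all(torch.all(state_dict[k] == 0) for k in matching):
            # All-zero LoRA tensors are a no-op, so removing them is lossless.
            for k in matching:
                state_dict.pop(k)
        return state_dict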

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@
     convert_state_dict_to_kohya,
     convert_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
+    state_dict_all_zero,
 )
 from .typing_utils import _get_detailed_type, _is_valid_type
 
src/diffusers/utils/state_dict_utils.py

Lines changed: 16 additions & 2 deletions
@@ -17,9 +17,14 @@
 
 import enum
 
+from .import_utils import is_torch_available
 from .logging import get_logger
 
 
+if is_torch_available():
+    import torch
+
+
 logger = get_logger(__name__)
 
 
@@ -64,8 +69,8 @@ class StateDictType(enum.Enum):
     ".lora_linear_layer.down": ".lora_A",
     "text_projection.lora.down.weight": "text_projection.lora_A.weight",
     "text_projection.lora.up.weight": "text_projection.lora_B.weight",
-    "position_embedding.lora.down.weight": "position_embedding.lora_A.weight",
-    "position_embedding.lora.up.weight": "position_embedding.lora_B.weight",
+    "position_embedding.lora.down.weight": "position_embedding.lora_embedding_A",
+    "position_embedding.lora.up.weight": "position_embedding.lora_embedding_B",
 }
 
 DIFFUSERS_OLD_TO_PEFT = {
@@ -335,3 +340,12 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
             kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
 
     return kohya_ss_state_dict
+
+
+def state_dict_all_zero(state_dict, filter_str=None):
+    if filter_str is not None:
+        if isinstance(filter_str, str):
+            filter_str = [filter_str]
+        state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
+
+    return all(torch.all(param == 0).item() for param in state_dict.values())
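
With the `__init__.py` change above, the helper becomes importable from `diffusers.utils`. A quick usage sketch with toy tensors (values invented for this example); the filter can be a single substring or a list of substrings, and omitting it checks every tensor in the state dict:

    import torch
    from diffusers.utils import state_dict_all_zero

    sd = {
        "position_embedding.lora_A.weight": torch.zeros(4, 8),
        "q_proj.lora_A.weight": torch.randn(4, 8),
    }

    state_dict_all_zero(sd, "position_embedding")   # True: the filtered tensors are all zeros
    state_dict_all_zero(sd, ["q_proj", "k_proj"])   # False: the q_proj weights are random
    state_dict_all_zero(sd)                         # False: no filter, so every tensor is checked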
