
Commit 812b4e1

Support more ComfyUI LoRAs.
1 parent f103993 commit 812b4e1

File tree

3 files changed (+151 lines, -8 lines):

src/diffusers/loaders/lora_base.py
src/diffusers/loaders/lora_conversion_utils.py
src/diffusers/utils/state_dict_utils.py


src/diffusers/loaders/lora_base.py

Lines changed: 14 additions & 0 deletions
@@ -358,6 +358,20 @@ def _load_lora_into_text_encoder(
         # convert state dict
         text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
 
+        if any("position_embedding" in k for k in text_encoder_lora_state_dict):
+            # TODO: this copying is a big shot in the dark.
+            # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=RM_Artistify_v1.0M.safetensors
+            # only has LoRA keys for the position embedding but not the LoRA embedding keys.
+            text_encoder_lora_state_dict[
+                "text_model.embeddings.position_embedding.lora_embedding_A.weight"
+            ] = text_encoder_lora_state_dict["text_model.embeddings.position_embedding.lora_A.weight"].clone()
+            text_encoder_lora_state_dict[
+                "text_model.embeddings.position_embedding.lora_embedding_B.weight"
+            ] = text_encoder_lora_state_dict["text_model.embeddings.position_embedding.lora_B.weight"].clone()
+            rank["text_model.embeddings.position_embedding.lora_B.weight"] = text_encoder_lora_state_dict[
+                "text_model.embeddings.position_embedding.lora_B.weight"
+            ].shape[1]
+
         for name, _ in text_encoder_attn_modules(text_encoder):
             for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
                 rank_key = f"{name}.{module}.lora_B.weight"
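A minimal sketch of what the new branch does, on a fabricated two-key state dict (shapes and the `rank` bookkeeping dict are illustrative, not taken from the linked checkpoint; the copy provides the `lora_embedding_A`/`lora_embedding_B` names that PEFT uses for embedding layers):

```python
import torch

# Toy state dict with only the position-embedding LoRA keys, as in the linked checkpoint.
sd = {
    "text_model.embeddings.position_embedding.lora_A.weight": torch.randn(4, 77),   # (rank, num_embeddings), illustrative
    "text_model.embeddings.position_embedding.lora_B.weight": torch.randn(768, 4),  # (embed_dim, rank), illustrative
}
rank = {}

if any("position_embedding" in k for k in sd):
    # Mirror the lora_A/lora_B tensors under the embedding-specific parameter names.
    prefix = "text_model.embeddings.position_embedding"
    sd[f"{prefix}.lora_embedding_A.weight"] = sd[f"{prefix}.lora_A.weight"].clone()
    sd[f"{prefix}.lora_embedding_B.weight"] = sd[f"{prefix}.lora_B.weight"].clone()
    # Record the rank from lora_B's second dimension, as the loader does.
    rank[f"{prefix}.lora_B.weight"] = sd[f"{prefix}.lora_B.weight"].shape[1]

print(rank)  # {'text_model.embeddings.position_embedding.lora_B.weight': 4}
```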

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 135 additions & 8 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+from typing import List
 
 import torch
 
@@ -22,6 +23,12 @@
 logger = logging.get_logger(__name__)
 
 
+def swap_scale_shift(weight):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
 def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
     # 1. get all state_dict_keys
     all_keys = list(state_dict.keys())
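For reference, `swap_scale_shift` just flips the two halves of a packed modulation weight; the sketch below uses a made-up 6-element vector so the swap is visible (in the Flux conversion it is applied because the original adaLN weights pack `(shift, scale)` while diffusers expects `(scale, shift)`):

```python
import torch

def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

# First half plays the role of "shift", second half of "scale".
weight = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
print(swap_scale_shift(weight))  # tensor([1., 1., 1., 0., 0., 0.])
```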
@@ -299,7 +306,9 @@ def _convert_text_encoder_lora_key(key, lora_name):
         key_to_replace = "lora_te2_"
 
     diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
+
     diffusers_name = diffusers_name.replace("text.model", "text_model")
+    diffusers_name = diffusers_name.replace("position.embedding", "position_embedding")
     diffusers_name = diffusers_name.replace("self.attn", "self_attn")
     diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
     diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
@@ -313,6 +322,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
     # Be aware that this is the new diffusers convention and the rest of the code might
     # not utilize it yet.
     diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+
     return diffusers_name
 
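As a rough illustration of the new rename (the key below is fabricated in the kohya `lora_te1_*` style), the de-underscoring pass followed by the two fix-ups recovers a proper `position_embedding` module path:

```python
# Fabricated kohya-style text encoder key.
key = "lora_te1_text_model_embeddings_position_embedding.lora_down.weight"

diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("position.embedding", "position_embedding")
print(diffusers_name)
# text_model.embeddings.position_embedding.lora.down.weight
# (the remaining ".lora.down" suffix is handled by the later replaces and the PEFT mapping)
```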

@@ -341,7 +351,8 @@ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
 
     # scale weight by alpha and dim
     rank = down_weight.shape[0]
-    alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
+    default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+    alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()  # alpha is scalar
     scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
 
     # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
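The effect of the new fallback, sketched on a made-up module key: when the kohya `.alpha` entry is absent, `alpha` defaults to the rank, so `scale = alpha / rank` collapses to 1.0 and the LoRA weights pass through unscaled:

```python
import torch

sds_sd = {}  # no "<key>.alpha" entry in this (fabricated) state dict
down_weight = torch.randn(8, 3072)  # rank 8, illustrative
rank = down_weight.shape[0]

default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
alpha = sds_sd.pop("lora_unet_img_in.alpha", default_alpha).item()
print(alpha / rank)  # 1.0
```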
@@ -362,7 +373,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     sd_lora_rank = down_weight.shape[0]
 
     # scale weight by alpha and dim
-    alpha = sds_sd.pop(sds_key + ".alpha")
+    default_alpha = torch.tensor(
+        sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+    )
+    alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
     scale = alpha / sd_lora_rank
 
     # calculate scale_down and scale_up
@@ -516,10 +530,62 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
             f"transformer.single_transformer_blocks.{i}.norm.linear",
         )
 
+    # TODO: alphas.
+    if any("final_layer" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            # Notice the swap.
+            ait_sd[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
+                sds_sd.pop(f"lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight")
+            )
+            ait_sd[f"proj_out.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_final_layer_linear.{orig_lora_key}.weight"
+            )
+
+    if any("guidance_in" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            ait_sd[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight"
+            )
+            ait_sd[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight"
+            )
+
+    if any("img_in" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            ait_sd[f"x_embedder.{lora_key}.weight"] = sds_sd.pop(f"lora_unet_img_in.{orig_lora_key}.weight")
+
+    if any("txt_in" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            ait_sd[f"context_embedder.{lora_key}.weight"] = sds_sd.pop(f"lora_unet_txt_in.{orig_lora_key}.weight")
+
+    if any("time_in" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            ait_sd[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_time_in_in_layer.{orig_lora_key}.weight"
+            )
+            ait_sd[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_time_in_out_layer.{orig_lora_key}.weight"
+            )
+
+    if any("vector_in" in k for k in sds_sd):
+        for lora_key in ["lora_A", "lora_B"]:
+            orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+            ait_sd[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_vector_in_in_layer.{orig_lora_key}.weight"
+            )
+            ait_sd[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = sds_sd.pop(
+                f"lora_unet_vector_in_out_layer.{orig_lora_key}.weight"
+            )
+
     remaining_keys = list(sds_sd.keys())
     te_state_dict = {}
     if remaining_keys:
-        if not all(k.startswith("lora_te") for k in remaining_keys):
+        if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
             raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
         for key in remaining_keys:
             if not key.endswith("lora_down.weight"):
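A toy run of the `img_in` branch above (fabricated tensors, rank 8) shows the sd-scripts key pair being renamed into the diffusers `x_embedder` LoRA keys:

```python
import torch

sds_sd = {
    "lora_unet_img_in.lora_down.weight": torch.randn(8, 64),
    "lora_unet_img_in.lora_up.weight": torch.randn(3072, 8),
}
ait_sd = {}

if any("img_in" in k for k in sds_sd):
    for lora_key in ["lora_A", "lora_B"]:
        orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
        ait_sd[f"x_embedder.{lora_key}.weight"] = sds_sd.pop(f"lora_unet_img_in.{orig_lora_key}.weight")

print(sorted(ait_sd))  # ['x_embedder.lora_A.weight', 'x_embedder.lora_B.weight']
print(sds_sd)          # {} (the source keys were consumed)
```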
@@ -680,10 +746,59 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
     if has_peft_state_dict:
         state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
         return state_dict
+
     # Another weird one.
     has_mixture = any(
         k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
     )
+
+    # ComfyUI.
+    state_dict = {k.replace("diffusion_model.", "lora_unet."): v for k, v in state_dict.items()}
+    state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te."): v for k, v in state_dict.items()}
+    has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
+    if has_t5xxl:
+        logger.info(
+            "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out. "
+            "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
+        )
+        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
+
+    any_diffb_keys = any("diff_b" in k and k.startswith(("lora_unet.", "lora_te.")) for k in state_dict)
+    if any_diffb_keys:
+        logger.info(
+            "`diff_b` keys found in the state dict which are currently unsupported. "
+            "So, we will filter out those keys. Open an issue if this is a problem - "
+            "https://github.com/huggingface/diffusers/issues/new."
+        )
+        state_dict = {k: v for k, v in state_dict.items() if "diff_b" not in k}
+
+    any_norm_diff_keys = any("norm" in k and "diff" in k for k in state_dict)
+    if any_norm_diff_keys:
+        logger.info(
+            "Normalization diff keys found in the state dict which are currently unsupported. "
+            "So, we will filter out those keys. Open an issue if this is a problem - "
+            "https://github.com/huggingface/diffusers/issues/new."
+        )
+        state_dict = {k: v for k, v in state_dict.items() if "norm" not in k and "diff" not in k}
+
+    limit_substrings = ["lora_down", "lora_up"]
+    if any("alpha" in k for k in state_dict):
+        limit_substrings.append("alpha")
+
+    state_dict = {
+        _custom_replace(k, limit_substrings): v
+        for k, v in state_dict.items()
+        if k.startswith(("lora_unet.", "lora_te."))
+    }
+
+    if any("text_projection" in k for k in state_dict):
+        logger.info(
+            "`text_projection` keys found in the state_dict which are unexpected. "
+            "So, we will filter out those keys. Open an issue if this is a problem - "
+            "https://github.com/huggingface/diffusers/issues/new."
+        )
+        state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
+
     if has_mixture:
        return _convert_mixture_state_dict_to_diffusers(state_dict)
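A sketch of the new ComfyUI prefix handling on fabricated keys: the `diffusion_model.` and `text_encoders.clip_l.transformer.` prefixes are mapped to the kohya-style prefixes the rest of the converter expects, and unsupported T5-XXL keys are dropped:

```python
state_dict = {
    "diffusion_model.double_blocks.0.img_attn.qkv.lora_down.weight": 0,
    "text_encoders.clip_l.transformer.text_model.encoder.layers.0.self_attn.q_proj.lora_down.weight": 0,
    "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.lora_down.weight": 0,
}

state_dict = {k.replace("diffusion_model.", "lora_unet."): v for k, v in state_dict.items()}
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te."): v for k, v in state_dict.items()}
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}

for k in sorted(state_dict):
    print(k)
# lora_te.text_model.encoder.layers.0.self_attn.q_proj.lora_down.weight
# lora_unet.double_blocks.0.img_attn.qkv.lora_down.weight
```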

@@ -798,6 +913,23 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     return new_state_dict
 
 
+def _custom_replace(key: str, substrings: List[str]) -> str:
+    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
+
+    match = re.search(pattern, key)
+    if match:
+        start_sub = match.start()
+        if start_sub > 0 and key[start_sub - 1] == ".":
+            boundary = start_sub - 1
+        else:
+            boundary = start_sub
+        left = key[:boundary].replace(".", "_")
+        right = key[boundary:]
+        return left + right
+    else:
+        return key.replace(".", "_")
+
+
 def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     converted_state_dict = {}
     original_state_dict_keys = list(original_state_dict.keys())
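Usage sketch for `_custom_replace`: dots in the module path are flattened to underscores, while everything from the matched substring onward (here `.lora_down.weight`) is kept intact, which restores the kohya-style key shape the converter works with:

```python
# Assumes a diffusers checkout that includes this commit (the helper is private).
from diffusers.loaders.lora_conversion_utils import _custom_replace

key = "lora_unet.double_blocks.0.img_attn.qkv.lora_down.weight"
print(_custom_replace(key, ["lora_down", "lora_up", "alpha"]))
# lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight
```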
@@ -806,11 +938,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     inner_dim = 3072
     mlp_ratio = 4.0
 
-    def swap_scale_shift(weight):
-        shift, scale = weight.chunk(2, dim=0)
-        new_weight = torch.cat([scale, shift], dim=0)
-        return new_weight
-
     for lora_key in ["lora_A", "lora_B"]:
         ## time_text_embed.timestep_embedder <- time_in
         converted_state_dict[

src/diffusers/utils/state_dict_utils.py

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,8 @@ class StateDictType(enum.Enum):
     ".lora_linear_layer.down": ".lora_A",
     "text_projection.lora.down.weight": "text_projection.lora_A.weight",
     "text_projection.lora.up.weight": "text_projection.lora_B.weight",
+    "position_embedding.lora.down.weight": "position_embedding.lora_A.weight",
+    "position_embedding.lora.up.weight": "position_embedding.lora_B.weight",
 }
 
 DIFFUSERS_OLD_TO_PEFT = {
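Roughly, the two new entries let the diffusers-to-PEFT conversion rewrite position-embedding keys that still use the `lora.down`/`lora.up` spelling into the `lora_A`/`lora_B` names PEFT expects; the manual substring replacement below (on a fabricated key) mirrors what the mapping does:

```python
mapping = {
    "position_embedding.lora.down.weight": "position_embedding.lora_A.weight",
    "position_embedding.lora.up.weight": "position_embedding.lora_B.weight",
}

key = "text_model.embeddings.position_embedding.lora.down.weight"
for old, new in mapping.items():
    key = key.replace(old, new)
print(key)  # text_model.embeddings.position_embedding.lora_A.weight
```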
