Skip to content

Commit 367153d

Browse files
committed
fix
1 parent 812b4e1 commit 367153d

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
341341

342342

343343
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
344-
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
345-
# All credits go to `kohya-ss`.
344+
# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
346345
def _convert_kohya_flux_lora_to_diffusers(state_dict):
347346
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
348347
if sds_key + ".lora_down.weight" not in sds_sd:
@@ -753,8 +752,8 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
753752
)
754753

755754
# ComfyUI.
756-
state_dict = {k.replace("diffusion_model.", "lora_unet."): v for k, v in state_dict.items()}
757-
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te."): v for k, v in state_dict.items()}
755+
state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
756+
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
758757
has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
759758
if has_t5xxl:
760759
logger.info(
@@ -763,7 +762,7 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
763762
)
764763
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
765764

766-
any_diffb_keys = any("diff_b" in k and k.startswith(("lora_unet.", "lora_te.")) for k in state_dict)
765+
any_diffb_keys = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
767766
if any_diffb_keys:
768767
logger.info(
769768
"`diff_b` keys found in the state dict which are currently unsupported. "
@@ -788,12 +787,12 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
788787
state_dict = {
789788
_custom_replace(k, limit_substrings): v
790789
for k, v in state_dict.items()
791-
if k.startswith(("lora_unet.", "lora_te."))
790+
if k.startswith(("lora_unet_", "lora_te_"))
792791
}
793792

794793
if any("text_projection" in k for k in state_dict):
795794
logger.info(
796-
"`text_projection` keys found in the state_dict which are unexpected. "
795+
"`text_projection` keys found in the `state_dict` which are unexpected. "
797796
"So, we will filter out those keys. Open an issue if this is a problem - "
798797
"https://github.com/huggingface/diffusers/issues/new."
799798
)

0 commit comments

Comments
 (0)