Commit 26149c0

sayakpaul and hlky authored

[LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187)
* updates
* updates
* updates
* updates
* notebooks revert
* fix-copies.
* seeing
* fix
* revert
* fixes
* fixes
* fixes
* remove print
* fix
* conflicts ii.
* updates
* fixes
* better filtering of prefix.

---------

Co-authored-by: hlky <hlky@hlky.ac>
1 parent 0703ce8 commit 26149c0
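For context, a minimal usage sketch of the behavior this commit targets (the model and LoRA repo ids below are placeholders, not taken from this commit): after this change, loading a LoRA whose state dict has no keys under a component's prefix logs an informational message instead of silently doing nothing.

```python
# Hypothetical sketch; repo ids are placeholders, not from this commit.
from diffusers import DiffusionPipeline
from diffusers.utils import logging

logging.set_verbosity_info()  # surface diffusers INFO-level messages

pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

# Suppose this LoRA contains only UNet weights. The text-encoder load is then a no-op,
# and with this commit an info message is emitted along the lines of:
#   "No LoRA keys associated to CLIPTextModel found with the prefix='text_encoder'. ..."
pipe.load_lora_weights("some-user/unet-only-lora")
```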

5 files changed: +244 -223 lines changed

src/diffusers/loaders/lora_base.py

Lines changed: 77 additions & 77 deletions
@@ -339,93 +339,93 @@ def _load_lora_into_text_encoder(
     # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
     # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
     # their prefixes.
-    keys = list(state_dict.keys())
     prefix = text_encoder_name if prefix is None else prefix

-    # Safe prefix to check with.
-    if any(text_encoder_name in key for key in keys):
-        # Load the layers corresponding to text encoder and make necessary adjustments.
-        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-        text_encoder_lora_state_dict = {
-            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-        }
+    # Load the layers corresponding to text encoder and make necessary adjustments.
+    if prefix is not None:
+        state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+
+    if len(state_dict) > 0:
+        logger.info(f"Loading {prefix}.")
+        rank = {}
+        state_dict = convert_state_dict_to_diffusers(state_dict)
+
+        # convert state dict
+        state_dict = convert_state_dict_to_peft(state_dict)
+
+        for name, _ in text_encoder_attn_modules(text_encoder):
+            for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                rank_key = f"{name}.{module}.lora_B.weight"
+                if rank_key not in state_dict:
+                    continue
+                rank[rank_key] = state_dict[rank_key].shape[1]
+
+        for name, _ in text_encoder_mlp_modules(text_encoder):
+            for module in ("fc1", "fc2"):
+                rank_key = f"{name}.{module}.lora_B.weight"
+                if rank_key not in state_dict:
+                    continue
+                rank[rank_key] = state_dict[rank_key].shape[1]
+
+        if network_alphas is not None:
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+
+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"]:
+                if is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<", "0.9.0"):
+                    lora_config_kwargs.pop("use_dora")
+
+        if "lora_bias" in lora_config_kwargs:
+            if lora_config_kwargs["lora_bias"]:
+                if is_peft_version("<=", "0.13.2"):
+                    raise ValueError(
+                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<=", "0.13.2"):
+                    lora_config_kwargs.pop("lora_bias")

-        if len(text_encoder_lora_state_dict) > 0:
-            logger.info(f"Loading {prefix}.")
-            rank = {}
-            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
-            # convert state dict
-            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
-            for name, _ in text_encoder_attn_modules(text_encoder):
-                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                    rank_key = f"{name}.{module}.lora_B.weight"
-                    if rank_key not in text_encoder_lora_state_dict:
-                        continue
-                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-            for name, _ in text_encoder_mlp_modules(text_encoder):
-                for module in ("fc1", "fc2"):
-                    rank_key = f"{name}.{module}.lora_B.weight"
-                    if rank_key not in text_encoder_lora_state_dict:
-                        continue
-                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-            if network_alphas is not None:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-
-            if "lora_bias" in lora_config_kwargs:
-                if lora_config_kwargs["lora_bias"]:
-                    if is_peft_version("<=", "0.13.2"):
-                        raise ValueError(
-                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<=", "0.13.2"):
-                        lora_config_kwargs.pop("lora_bias")
+        lora_config = LoraConfig(**lora_config_kwargs)

-            lora_config = LoraConfig(**lora_config_kwargs)
+        # adapter_name
+        if adapter_name is None:
+            adapter_name = get_adapter_name(text_encoder)

-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(text_encoder)
+        is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)

-            is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+        # inject LoRA layers and load the state dict
+        # in transformers we automatically check whether the adapter name is already in use or not
+        text_encoder.load_adapter(
+            adapter_name=adapter_name,
+            adapter_state_dict=state_dict,
+            peft_config=lora_config,
+            **peft_kwargs,
+        )

-            # inject LoRA layers and load the state dict
-            # in transformers we automatically check whether the adapter name is already in use or not
-            text_encoder.load_adapter(
-                adapter_name=adapter_name,
-                adapter_state_dict=text_encoder_lora_state_dict,
-                peft_config=lora_config,
-                **peft_kwargs,
-            )
+        # scale LoRA layers with `lora_scale`
+        scale_lora_layers(text_encoder, weight=lora_scale)

-            # scale LoRA layers with `lora_scale`
-            scale_lora_layers(text_encoder, weight=lora_scale)
+        text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

-            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />

-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+    if prefix is not None and not state_dict:
+        logger.info(
+            f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {text_encoder.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
+        )


 def _func_optionally_disable_offloading(_pipeline):
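As a standalone illustration of the new filtering step above (a sketch only, not the diffusers function itself): the keys are narrowed to the component prefix and stripped of it, and an empty result is exactly the no-op case that the added logger.info call now reports.

```python
# Standalone sketch of the prefix filtering; values are placeholder strings, not tensors.
state_dict = {
    "unet.down_blocks.0.attentions.0.proj_in.lora_A.weight": "placeholder",
    "unet.down_blocks.0.attentions.0.proj_in.lora_B.weight": "placeholder",
}
prefix = "text_encoder"

# Keep only keys under `prefix` and strip the prefix from them.
filtered = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

if not filtered:
    # This is the case the new logger.info call reports instead of silently returning.
    print(f"No LoRA keys found with {prefix=}; loading into the text encoder is a no-op.")
```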
