Commit 10bee52

[LoRA] use removeprefix to preserve sanity. (#11493)
* use removeprefix to preserve sanity.
* f-string.
1 parent d88ae1f commit 10bee52
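
The change: every place that stripped a LoRA key prefix by slicing (k[len(f"{prefix}.") :]) or by k.replace(f"{prefix}.", "") now calls str.removeprefix, available since Python 3.9. A minimal sketch, not part of the commit and using hypothetical keys, of why the three spellings are not interchangeable:

prefix = "unet"
key = "unet.down_blocks.0.unet.attn.weight"      # hypothetical key; the prefix string also appears mid-key

key[len(f"{prefix}.") :]          # 'down_blocks.0.unet.attn.weight'
key.removeprefix(f"{prefix}.")    # 'down_blocks.0.unet.attn.weight' (same result, intent is explicit)
key.replace(f"{prefix}.", "")     # 'down_blocks.0.attn.weight'  <- strips every occurrence, not just the leading one

other = "text_encoder.lora_B.weight"             # hypothetical key that does not carry the prefix
other[len(f"{prefix}.") :]        # 'encoder.lora_B.weight'  <- blindly drops five characters
other.removeprefix(f"{prefix}.")  # 'text_encoder.lora_B.weight' (unchanged; no-op when the prefix is absent)

The slicing form was not buggy here, since the surrounding code filters keys on the prefix first, but removeprefix states the intent directly; the replace form could silently mangle a key in which the prefix string recurs.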

3 files changed: 8 additions, 6 deletions

src/diffusers/loaders/lora_base.py (2 additions & 2 deletions)

@@ -348,7 +348,7 @@ def _load_lora_into_text_encoder(

     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
-        state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")

@@ -374,7 +374,7 @@ def _load_lora_into_text_encoder(

     if network_alphas is not None:
         alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-        network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+        network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

     lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
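
For reference, a minimal sketch, with dummy keys and values, of the filter-and-strip comprehension used in _load_lora_into_text_encoder above:

prefix = "text_encoder"
state_dict = {
    "text_encoder.layers.0.lora_A.weight": "A0",   # hypothetical keys
    "text_encoder.layers.0.lora_B.weight": "B0",
    "unet.down_blocks.0.lora_A.weight": "skipped",
}
# Keep only keys under the prefix and drop the leading "text_encoder." segment.
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
# {'layers.0.lora_A.weight': 'A0', 'layers.0.lora_B.weight': 'B0'}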

src/diffusers/loaders/lora_pipeline.py (2 additions & 2 deletions)

@@ -2103,7 +2103,7 @@ def _load_norm_into_transformer(
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+                state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)

         # Find invalid keys
         transformer_state_dict = transformer.state_dict()

@@ -2425,7 +2425,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+                state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)

         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
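
Both hunks use the same in-place rename pattern; a minimal sketch with dummy keys (list() is needed because the dict is mutated while being iterated):

prefix = "transformer"
state_dict = {
    "transformer.norm_out.linear.weight": "w",   # hypothetical keys
    "vae.decoder.conv.weight": "v",
}
for key in list(state_dict.keys()):
    if key.split(".")[0] == prefix:
        # pop the prefixed entry and reinsert it under the stripped key
        state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
# {'vae.decoder.conv.weight': 'v', 'norm_out.linear.weight': 'w'}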

src/diffusers/loaders/peft.py (4 additions & 2 deletions)

@@ -230,7 +230,7 @@ def load_lora_adapter(
            raise ValueError("`network_alphas` cannot be None when `prefix` is None.")

        if prefix is not None:
-            state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+            state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

        if len(state_dict) > 0:
            if adapter_name in getattr(self, "peft_config", {}) and not hotswap:

@@ -261,7 +261,9 @@ def load_lora_adapter(

            if network_alphas is not None and len(network_alphas) >= 1:
                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+                network_alphas = {
+                    k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
+                }

            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
            _maybe_raise_error_for_ambiguity(lora_config_kwargs)
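
The second hunk swaps replace for removeprefix and reflows the comprehension across three lines (hence the 4 additions and 2 deletions in this file); a minimal sketch with dummy alpha values:

prefix = "transformer"
network_alphas = {
    "transformer.blocks.0.attn.alpha": 8.0,   # hypothetical entries
    "text_encoder.layers.0.alpha": 4.0,
}
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {
    k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}
# {'blocks.0.attn.alpha': 8.0}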
