
Commit ba546bc: fix tests
1 parent: d8a305e

5 files changed: 8 additions, 11 deletions

src/diffusers/loaders/lora_pipeline.py (1 addition, 1 deletion)

@@ -4889,7 +4889,7 @@ def load_lora_weights(
             adapter_name=adapter_name,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            load_with_metdata=load_with_metdata,
+            load_with_metadata=load_with_metdata,
         )

     @classmethod

src/diffusers/loaders/peft.py (2 additions, 5 deletions)

@@ -236,15 +236,12 @@ def load_lora_adapter(
         if network_alphas is not None and prefix is None:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")

-        if load_with_metadata is not None and not use_safetensors:
-            raise ValueError("`load_with_metadata` cannot be specified when not using `use_safetensors`.")
-
         if prefix is not None:
-            metadata = state_dict.pop("_metadata", None)
+            metadata = state_dict.pop("lora_metadata", None)
             state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

             if metadata is not None:
-                state_dict["_metadata"] = metadata
+                state_dict["lora_metadata"] = metadata

         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
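As context for the peft.py hunk, the prefix handling it touches boils down to the pattern below. This is a minimal standalone sketch, not the actual `load_lora_adapter` body; the helper name `_scope_lora_state_dict` is made up for illustration.

def _scope_lora_state_dict(state_dict, prefix):
    # Pop the metadata first so the prefix filter below does not discard it.
    metadata = state_dict.pop("lora_metadata", None)
    # Keep only the entries belonging to this component and strip the prefix.
    state_dict = {
        k[len(f"{prefix}.") :]: v
        for k, v in state_dict.items()
        if k.startswith(f"{prefix}.")
    }
    # Re-attach the metadata under the renamed key so downstream helpers
    # such as get_peft_kwargs can still find it.
    if metadata is not None:
        state_dict["lora_metadata"] = metadata
    return state_dict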

src/diffusers/utils/peft_utils.py (2 additions, 2 deletions)

@@ -151,9 +151,9 @@ def get_peft_kwargs(
     rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False
 ):
     if load_with_metadata:
-        if "_metadata" not in peft_state_dict:
+        if "lora_metadata" not in peft_state_dict:
             raise ValueError("Couldn't find '_metadata' key in the `peft_state_dict`.")
-        metadata = peft_state_dict["_metadata"]
+        metadata = peft_state_dict["lora_metadata"]
         if prefix is not None:
             metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
         return metadata

src/diffusers/utils/state_dict_utils.py (1 addition, 1 deletion)

@@ -360,7 +360,7 @@ def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_ke
         metadata_keys = list(metadata.keys())
         if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"):
             peft_metadata = {k: v for k, v in metadata.items() if k != "format"}
-            state_dict["_metadata"] = json.loads(peft_metadata[metadata_key])
+            state_dict["lora_metadata"] = json.loads(peft_metadata[metadata_key])
         else:
             raise ValueError("Metadata couldn't be parsed from the safetensors file.")
     return state_dict
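For reference, the metadata that this hunk renames originates from the safetensors file-level header. Below is a rough sketch of that read-and-parse step, assuming the adapter config is stored as a JSON string under `metadata_key`; the function name `read_lora_metadata` is illustrative, not the library's API.

import json

from safetensors import safe_open


def read_lora_metadata(model_file, metadata_key):
    # safetensors exposes file-level metadata as a flat str -> str mapping.
    with safe_open(model_file, framework="pt") as f:
        metadata = f.metadata() or {}
    # A "format" entry (e.g. "pt") is typically present; anything else is custom metadata.
    peft_metadata = {k: v for k, v in metadata.items() if k != "format"}
    if not peft_metadata:
        return None
    # The adapter config itself is serialized as JSON under `metadata_key`.
    return json.loads(peft_metadata[metadata_key])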

tests/lora/test_lora_layers_wan.py (2 additions, 2 deletions)

@@ -162,9 +162,9 @@ def test_adapter_metadata_is_loaded_correctly(self):
             pipe.unload_lora_weights()
             state_dict = pipe.lora_state_dict(tmpdir, load_with_metadata=True)

-            self.assertTrue("_metadata" in state_dict)
+            self.assertTrue("lora_metadata" in state_dict)

-            parsed_metadata = state_dict["_metadata"]
+            parsed_metadata = state_dict["lora_metadata"]
             parsed_metadata = {k[len("transformer.") :]: v for k, v in parsed_metadata.items()}
             check_if_dicts_are_equal(parsed_metadata, metadata)
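Caller-side, the updated test corresponds to roughly the following pattern. This is a hedged sketch using only the calls visible in the diff; `pipe`, `tmpdir`, and the `transformer.` prefix come from the test itself.

# Reload the raw LoRA state dict together with its metadata.
state_dict = pipe.lora_state_dict(tmpdir, load_with_metadata=True)

# The parsed adapter metadata is now exposed under "lora_metadata" (previously "_metadata").
metadata = state_dict["lora_metadata"]

# Metadata keys are namespaced per component, e.g. "transformer.<field>"; strip the prefix.
metadata = {k[len("transformer.") :]: v for k, v in metadata.items()}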
