Skip to content

Commit c28d6f3

Browse files
committed
add tests
1 parent d3afa26 commit c28d6f3

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

src/diffusers/loaders/lora_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,18 @@ def _optionally_disable_offloading(cls, _pipeline):
349349

350350
return (is_model_cpu_offload, is_sequential_cpu_offload)
351351

352+
@classmethod
353+
def _fetch_state_dict(cls, *args, **kwargs):
354+
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
355+
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
356+
_fetch_state_dict(*args, **kwargs)
357+
358+
@classmethod
359+
def _best_guess_weight_name(cls, *args, **kwargs):
360+
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
361+
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
362+
_best_guess_weight_name(*args, **kwargs)
363+
352364
def unload_lora_weights(self):
353365
"""
354366
Unloads the LoRA parameters.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
import tempfile
3+
import unittest
4+
5+
import torch
6+
7+
from diffusers.loaders.lora_base import LoraBaseMixin
8+
9+
10+
class UtilityMethodDeprecationTests(unittest.TestCase):
11+
def test_fetch_state_dict_cls_method_raises_warning(self):
12+
state_dict = torch.nn.Linear(3, 3).state_dict()
13+
with self.assertWarns(FutureWarning) as warning:
14+
_ = LoraBaseMixin._fetch_state_dict(
15+
state_dict,
16+
weight_name=None,
17+
use_safetensors=False,
18+
local_files_only=True,
19+
cache_dir=None,
20+
force_download=False,
21+
proxies=None,
22+
token=None,
23+
revision=None,
24+
subfolder=None,
25+
user_agent=None,
26+
allow_pickle=None,
27+
)
28+
warning_message = str(warning.warnings[0].message)
29+
assert "Using the `_fetch_state_dict()` method from" in warning_message
30+
31+
def test_best_guess_weight_name_cls_method_raises_warning(self):
32+
with tempfile.TemporaryDirectory() as tmpdir:
33+
state_dict = torch.nn.Linear(3, 3).state_dict()
34+
torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))
35+
36+
with self.assertWarns(FutureWarning) as warning:
37+
_ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)
38+
warning_message = str(warning.warnings[0].message)
39+
assert "Using the `_best_guess_weight_name()` method from" in warning_message

0 commit comments

Comments
 (0)