Commit 7d0b9c4

[LoRA] feat: save_lora_adapter() (#9862)
* feat: save_lora_adapter.
1 parent acf479b commit 7d0b9c4

File tree

6 files changed: +210 -55 lines

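In short: models inheriting PeftAdapterMixin gain a save_lora_adapter() counterpart to load_lora_adapter(). A minimal round-trip sketch under assumptions not spelled out in the diff (a default-constructed UNet2DConditionModel as the host model, peft installed, and "out_dir" as a placeholder directory):

import torch
from peft import LoraConfig
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel()  # any model that inherits PeftAdapterMixin works
model.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

# New in this commit: serialize the adapter straight from the model.
model.save_lora_adapter("out_dir")  # writes out_dir/pytorch_lora_weights.safetensors

# Round-trip: drop the adapter, then restore it from disk.
model.unload_lora()
model.load_lora_adapter("out_dir", prefix=None)  # prefix=None: standalone checkpoints are not namespaced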

src/diffusers/loaders/lora_pipeline.py

Lines changed: 4 additions & 2 deletions
@@ -298,8 +298,9 @@ def load_lora_into_unet(
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.load_attn_procs(
+            unet.load_lora_adapter(
                 state_dict,
+                prefix=cls.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
@@ -827,8 +828,9 @@ def load_lora_into_unet(
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.load_attn_procs(
+            unet.load_lora_adapter(
                 state_dict,
+                prefix=cls.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
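This diff switches pipeline-level UNet LoRA loading from load_attn_procs() to the model-level load_lora_adapter(), passing prefix=cls.unet_name so that only the unet-namespaced entries of a pipeline checkpoint are consumed. A small sketch of the filtering that prefix triggers (key names and tensor shapes are hypothetical):

import torch

state_dict = {
    "unet.down_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 320),
    "text_encoder.layers.0.q_proj.lora_A.weight": torch.zeros(4, 768),
}
prefix = "unet"
# Mirrors the selection in load_lora_adapter(): keep only keys under the prefix, then strip it.
model_keys = [k for k in state_dict if k.startswith(f"{prefix}.")]
unet_state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
print(list(unet_state_dict))  # ['down_blocks.0.attn.to_q.lora_A.weight']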

src/diffusers/loaders/peft.py

Lines changed: 88 additions & 16 deletions
@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import os
 from functools import partial
+from pathlib import Path
 from typing import Dict, List, Optional, Union
 
+import safetensors
+import torch
 import torch.nn as nn
 
 from ..utils import (
@@ -189,40 +193,45 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        if network_alphas is not None and prefix is None:
+            raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
-        keys = list(state_dict.keys())
-        transformer_keys = [k for k in keys if k.startswith(prefix)]
-        if len(transformer_keys) > 0:
-            state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
+        if prefix is not None:
+            keys = list(state_dict.keys())
+            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
+            if len(model_keys) > 0:
+                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
+
+        if len(state_dict) > 0:
+            if adapter_name in getattr(self, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
+                )
 
-        if len(state_dict.keys()) > 0:
             # check with first key if is not in peft format
             first_key = next(iter(state_dict.keys()))
             if "lora_A" not in first_key:
                 state_dict = convert_unet_state_dict_to_peft(state_dict)
 
-            if adapter_name in getattr(self, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
             rank = {}
             for key, val in state_dict.items():
                 if "lora_B" in key:
                     rank[key] = val.shape[1]
 
             if network_alphas is not None and len(network_alphas) >= 1:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
                 else:
-                    lora_config_kwargs.pop("use_dora")
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
             lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name
@@ -276,6 +285,69 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
                 _pipeline.enable_sequential_cpu_offload()
             # Unsafe code />
 
+    def save_lora_adapter(
+        self,
+        save_directory,
+        adapter_name: str = "default",
+        upcast_before_saving: bool = False,
+        safe_serialization: bool = True,
+        weight_name: Optional[str] = None,
+    ):
+        """
+        Save the LoRA parameters corresponding to the underlying model.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            adapter_name: (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the
+                underlying model has multiple adapters loaded.
+            upcast_before_saving (`bool`, defaults to `False`):
+                Whether to cast the underlying model to `torch.float32` before serialization.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
+        """
+        from peft.utils import get_peft_model_state_dict
+
+        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+        if adapter_name is None:
+            adapter_name = get_adapter_name(self)
+
+        if adapter_name not in getattr(self, "peft_config", {}):
+            raise ValueError(f"Adapter name {adapter_name} not found in the model.")
+
+        lora_layers_to_save = get_peft_model_state_dict(
+            self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
+        )
+        if os.path.isfile(save_directory):
+            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        if safe_serialization:
+
+            def save_function(weights, filename):
+                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+        else:
+            save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = LORA_WEIGHT_NAME
+
+        # TODO: we could consider saving the `peft_config` as well.
+        save_path = Path(save_directory, weight_name).as_posix()
+        save_function(lora_layers_to_save, save_path)
+        logger.info(f"Model weights saved in {save_path}")
+
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
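A few usage notes on the new method: when weight_name is omitted, the file name falls back to the library-wide defaults from lora_base (pytorch_lora_weights.safetensors for safetensors, pytorch_lora_weights.bin otherwise), and safe_serialization toggles between safetensors.torch.save_file and torch.save. A hedged sketch, reusing the model from the round-trip example near the top of this page ("my_adapter.bin" is a placeholder name):

model.save_lora_adapter("out_dir")  # default: out_dir/pytorch_lora_weights.safetensors
model.save_lora_adapter("out_dir", safe_serialization=False, weight_name="my_adapter.bin")  # pickle-based path
model.save_lora_adapter("out_dir", adapter_name="default", upcast_before_saving=True)  # cast to float32 first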

src/diffusers/loaders/unet.py

Lines changed: 5 additions & 0 deletions
@@ -36,6 +36,7 @@
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_unet_state_dict_to_peft,
+    deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
@@ -209,6 +210,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
+        if is_lora:
+            deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
+            deprecate("load_attn_procs", "0.40.0", deprecation_message)
+
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
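With this, load_attn_procs() is deprecated for LoRA state dicts and scheduled for removal in diffusers 0.40.0; custom-diffusion inputs are unaffected. A migration sketch for a standalone UNet LoRA checkpoint (the path is a placeholder):

# Before: triggers the deprecation warning added above.
unet.load_attn_procs("path/to/lora_dir")

# After: the peft-backed loader; prefix=None because standalone UNet
# checkpoints are not namespaced with "unet."
unet.load_lora_adapter("path/to/lora_dir", prefix=None)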

tests/lora/utils.py

Lines changed: 2 additions & 10 deletions
@@ -1784,11 +1784,7 @@ def test_missing_keys_warning(self):
         missing_key = [k for k in state_dict if "lora_A" in k][0]
         del state_dict[missing_key]
 
-        logger = (
-            logging.get_logger("diffusers.loaders.unet")
-            if self.unet_kwargs is not None
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)
@@ -1823,11 +1819,7 @@ def test_unexpected_keys_warning(self):
         unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
         state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
 
-        logger = (
-            logging.get_logger("diffusers.loaders.unet")
-            if self.unet_kwargs is not None
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)

tests/models/test_modeling_common.py

Lines changed: 103 additions & 2 deletions
@@ -44,6 +44,7 @@
 from diffusers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     WEIGHTS_INDEX_NAME,
+    is_peft_available,
     is_torch_npu_available,
     is_xformers_available,
     logging,
@@ -65,6 +66,10 @@
 from ..others.test_utils import TOKEN, USER, is_staging_test
 
 
+if is_peft_available():
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+
 def caculate_expected_num_shards(index_map_path):
     with open(index_map_path) as f:
         weight_map_dict = json.load(f)["weight_map"]
@@ -74,6 +79,16 @@ def caculate_expected_num_shards(index_map_path):
     return expected_num_shards
 
 
+def check_if_lora_correctly_set(model) -> bool:
+    """
+    Checks if the LoRA layers are correctly set with peft
+    """
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return True
+    return False
+
+
 # Will be run via run_test_in_subprocess
 def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     error = None
@@ -877,8 +892,6 @@ def _set_gradient_checkpointing_new(self, module, value=False):
         model = model_class_copy(**init_dict)
         model.enable_gradient_checkpointing()
 
-        print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
-
         assert set(modules_with_gc_enabled.keys()) == expected_set
         assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
 
@@ -902,6 +915,94 @@ def test_deprecated_kwargs(self):
             " from `_deprecated_kwargs = [<deprecated_argument>]`"
         )
 
+    @parameterized.expand([True, False])
+    @torch.no_grad()
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_save_load_lora_adapter(self, use_dora=False):
+        import safetensors
+        from peft import LoraConfig
+        from peft.utils import get_peft_model_state_dict
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        torch.manual_seed(0)
+        output_no_lora = model(**inputs_dict, return_dict=False)[0]
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+
+            model.unload_lora()
+            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
+
+            for k in state_dict_loaded:
+                loaded_v = state_dict_loaded[k]
+                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
+                self.assertTrue(torch.allclose(loaded_v, retrieved_v))
+
+            self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            torch.manual_seed(0)
+            outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+
+            self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+            self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_wrong_adapter_name_raises_error(self):
+        from peft import LoraConfig
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            wrong_name = "foo"
+            with self.assertRaises(ValueError) as err_context:
+                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+            self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
+
     @require_torch_gpu
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
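Note that both new tests live on the shared model tester mixin rather than on a concrete TestCase, so they are exercised through each model-specific suite instead of by running this file alone; an invocation like `pytest tests/models -k "save_load_lora_adapter or wrong_adapter_name"` should collect them across models (this command is a sketch, not taken from the diff).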
