|
13 | 13 | # See the License for the specific language governing permissions and
14 | 14 | # limitations under the License.
15 | 15 | import inspect
   | 16 | +import os
16 | 17 | from functools import partial
   | 18 | +from pathlib import Path
17 | 19 | from typing import Dict, List, Optional, Union
18 | 20 |
   | 21 | +import safetensors
   | 22 | +import torch
19 | 23 | import torch.nn as nn
20 | 24 |
21 | 25 | from ..utils import (
|
@@ -189,40 +193,45 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
|
189 | 193 |             user_agent=user_agent,
190 | 194 |             allow_pickle=allow_pickle,
191 | 195 |         )
    | 196 | +        if network_alphas is not None and prefix is None:
    | 197 | +            raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
192 | 198 |
193 |     | -        keys = list(state_dict.keys())
194 |     | -        transformer_keys = [k for k in keys if k.startswith(prefix)]
195 |     | -        if len(transformer_keys) > 0:
196 |     | -            state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
    | 199 | +        if prefix is not None:
    | 200 | +            keys = list(state_dict.keys())
    | 201 | +            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
    | 202 | +            if len(model_keys) > 0:
    | 203 | +                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
    | 204 | +
    | 205 | +        if len(state_dict) > 0:
    | 206 | +            if adapter_name in getattr(self, "peft_config", {}):
    | 207 | +                raise ValueError(
    | 208 | +                    f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
    | 209 | +                )
197 | 210 |
198 |     | -        if len(state_dict.keys()) > 0:
199 | 211 |             # check with first key if is not in peft format
200 | 212 |             first_key = next(iter(state_dict.keys()))
201 | 213 |             if "lora_A" not in first_key:
202 | 214 |                 state_dict = convert_unet_state_dict_to_peft(state_dict)
203 | 215 |
204 |     | -            if adapter_name in getattr(self, "peft_config", {}):
205 |     | -                raise ValueError(
206 |     | -                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
207 |     | -                )
208 |     | -
209 | 216 |             rank = {}
210 | 217 |             for key, val in state_dict.items():
211 | 218 |                 if "lora_B" in key:
212 | 219 |                     rank[key] = val.shape[1]
213 | 220 |
214 | 221 |             if network_alphas is not None and len(network_alphas) >= 1:
215 |     | -                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
    | 222 | +                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
216 | 223 |                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
217 | 224 |
218 | 225 |             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
219 | 226 |             if "use_dora" in lora_config_kwargs:
220 |     | -                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
221 |     | -                    raise ValueError(
222 |     | -                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
223 |     | -                    )
    | 227 | +                if lora_config_kwargs["use_dora"]:
    | 228 | +                    if is_peft_version("<", "0.9.0"):
    | 229 | +                        raise ValueError(
    | 230 | +                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
    | 231 | +                        )
224 | 232 |                 else:
225 |     | -                    lora_config_kwargs.pop("use_dora")
    | 233 | +                    if is_peft_version("<", "0.9.0"):
    | 234 | +                        lora_config_kwargs.pop("use_dora")
226 | 235 |             lora_config = LoraConfig(**lora_config_kwargs)
227 | 236 |
228 | 237 |             # adapter_name
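With the reworked prefix handling above, `load_lora_adapter` only filters and strips keys when `prefix` is not `None`, and it rejects `network_alphas` when no `prefix` is given. A minimal usage sketch of the prefixed path follows; the model class, Hub repo id, weight file name, and adapter name are illustrative placeholders, not part of this diff:

```python
import torch
from diffusers import FluxTransformer2DModel  # any model inheriting PeftAdapterMixin; this choice is an assumption

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Pipeline-style LoRA checkpoints name their keys "transformer.<module>.lora_A.weight" and so on;
# prefix="transformer" keeps only those keys and strips the prefix before the PEFT injection.
transformer.load_lora_adapter(
    "some-user/some-flux-lora",                      # placeholder repo id
    weight_name="pytorch_lora_weights.safetensors",  # placeholder file name
    prefix="transformer",
    adapter_name="demo",
)
```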
|
@@ -276,6 +285,69 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
|
276 | 285 |                 _pipeline.enable_sequential_cpu_offload()
277 | 286 |             # Unsafe code />
278 | 287 |
    | 288 | +    def save_lora_adapter(
    | 289 | +        self,
    | 290 | +        save_directory,
    | 291 | +        adapter_name: str = "default",
    | 292 | +        upcast_before_saving: bool = False,
    | 293 | +        safe_serialization: bool = True,
    | 294 | +        weight_name: Optional[str] = None,
    | 295 | +    ):
    | 296 | +        """
    | 297 | +        Save the LoRA parameters corresponding to the underlying model.
    | 298 | +
    | 299 | +        Arguments:
    | 300 | +            save_directory (`str` or `os.PathLike`):
    | 301 | +                Directory to save LoRA parameters to. Will be created if it doesn't exist.
    | 302 | +            adapter_name: (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the
    | 303 | +                underlying model has multiple adapters loaded.
    | 304 | +            upcast_before_saving (`bool`, defaults to `False`):
    | 305 | +                Whether to cast the underlying model to `torch.float32` before serialization.
    | 306 | +            save_function (`Callable`):
    | 307 | +                The function to use to save the state dictionary. Useful during distributed training when you need to
    | 308 | +                replace `torch.save` with another method. Can be configured with the environment variable
    | 309 | +                `DIFFUSERS_SAVE_MODE`.
    | 310 | +            safe_serialization (`bool`, *optional*, defaults to `True`):
    | 311 | +                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
    | 312 | +            weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
    | 313 | +        """
    | 314 | +        from peft.utils import get_peft_model_state_dict
    | 315 | +
    | 316 | +        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
    | 317 | +
    | 318 | +        if adapter_name is None:
    | 319 | +            adapter_name = get_adapter_name(self)
    | 320 | +
    | 321 | +        if adapter_name not in getattr(self, "peft_config", {}):
    | 322 | +            raise ValueError(f"Adapter name {adapter_name} not found in the model.")
    | 323 | +
    | 324 | +        lora_layers_to_save = get_peft_model_state_dict(
    | 325 | +            self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
    | 326 | +        )
    | 327 | +        if os.path.isfile(save_directory):
    | 328 | +            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
    | 329 | +
    | 330 | +        if safe_serialization:
    | 331 | +
    | 332 | +            def save_function(weights, filename):
    | 333 | +                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
    | 334 | +
    | 335 | +        else:
    | 336 | +            save_function = torch.save
    | 337 | +
    | 338 | +        os.makedirs(save_directory, exist_ok=True)
    | 339 | +
    | 340 | +        if weight_name is None:
    | 341 | +            if safe_serialization:
    | 342 | +                weight_name = LORA_WEIGHT_NAME_SAFE
    | 343 | +            else:
    | 344 | +                weight_name = LORA_WEIGHT_NAME
    | 345 | +
    | 346 | +        # TODO: we could consider saving the `peft_config` as well.
    | 347 | +        save_path = Path(save_directory, weight_name).as_posix()
    | 348 | +        save_function(lora_layers_to_save, save_path)
    | 349 | +        logger.info(f"Model weights saved in {save_path}")
    | 350 | +
279 | 351 |     def set_adapters(
280 | 352 |         self,
281 | 353 |         adapter_names: Union[List[str], str],
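A hedged round-trip sketch for the new `save_lora_adapter` method, reusing the hypothetical `transformer` and adapter name from the sketch above; the output directory is a placeholder. Because `get_peft_model_state_dict` returns keys without a pipeline-level `"transformer."` prefix, the saved file is reloaded with `prefix=None`, which is the case the first hunk explicitly starts to handle:

```python
from safetensors.torch import load_file

from diffusers import FluxTransformer2DModel  # same illustrative model class as above

# With safe_serialization=True this writes <save_directory>/pytorch_lora_weights.safetensors
# (LORA_WEIGHT_NAME_SAFE) via safetensors.
transformer.save_lora_adapter("out/flux-lora", adapter_name="demo")

# Reload into a fresh copy of the model. The serialized keys carry no "transformer." prefix,
# so prefix=None skips the key filtering and the state dict is consumed as-is.
fresh = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
state_dict = load_file("out/flux-lora/pytorch_lora_weights.safetensors")
fresh.load_lora_adapter(state_dict, prefix=None, adapter_name="demo")
```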
|
|