Skip to content

Commit 63805f8

Browse files
authored
Support convert LoRA safetensors into diffusers format (#2403)
* add lora convertor * Update convert_lora_safetensor_to_diffusers.py * Update README.md * Update convert_lora_safetensor_to_diffusers.py
1 parent 9920c33 commit 63805f8

File tree

2 files changed

+134
-4
lines changed

2 files changed

+134
-4
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -467,12 +467,12 @@ image.save("ddpm_generated_image.png")
467467
- [Unconditional Diffusion with continuous scheduler](https://huggingface.co/google/ncsnpp-ffhq-1024)
468468

469469
**Other Image Notebooks**:
470-
* [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
471-
* [tweak images via repeated Stable Diffusion seeds](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
470+
* [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb),
471+
* [tweak images via repeated Stable Diffusion seeds](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) ![Open In Colab](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb),
472472

473473
**Diffusers for Other Modalities**:
474-
* [Molecule conformation generation](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
475-
* [Model-based reinforcement learning](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
474+
* [Molecule conformation generation](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) ![Open In Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb),
475+
* [Model-based reinforcement learning](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb),
476476

477477
### Web Demos
478478
If you just want to play around with some web demos, you can try out the following 🚀 Spaces:
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# coding=utf-8
2+
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
""" Conversion script for the LoRA's safetensors checkpoints. """
17+
18+
import argparse
19+
20+
import torch
21+
from safetensors.torch import load_file
22+
23+
from diffusers import StableDiffusionPipeline
24+
25+
26+
def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET, LORA_PREFIX_TEXT_ENCODER, alpha):
27+
28+
# load base model
29+
pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
30+
31+
# load LoRA weight from .safetensors
32+
state_dict = load_file(checkpoint_path)
33+
34+
visited = []
35+
36+
# directly update weight in diffusers model
37+
for key in state_dict:
38+
39+
# it is suggested to print out the key, it usually will be something like below
40+
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
41+
42+
# as we have set the alpha beforehand, so just skip
43+
if ".alpha" in key or key in visited:
44+
continue
45+
46+
if "text" in key:
47+
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
48+
curr_layer = pipeline.text_encoder
49+
else:
50+
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
51+
curr_layer = pipeline.unet
52+
53+
# find the target layer
54+
temp_name = layer_infos.pop(0)
55+
while len(layer_infos) > -1:
56+
try:
57+
curr_layer = curr_layer.__getattr__(temp_name)
58+
if len(layer_infos) > 0:
59+
temp_name = layer_infos.pop(0)
60+
elif len(layer_infos) == 0:
61+
break
62+
except Exception:
63+
if len(temp_name) > 0:
64+
temp_name += "_" + layer_infos.pop(0)
65+
else:
66+
temp_name = layer_infos.pop(0)
67+
68+
pair_keys = []
69+
if "lora_down" in key:
70+
pair_keys.append(key.replace("lora_down", "lora_up"))
71+
pair_keys.append(key)
72+
else:
73+
pair_keys.append(key)
74+
pair_keys.append(key.replace("lora_up", "lora_down"))
75+
76+
# update weight
77+
if len(state_dict[pair_keys[0]].shape) == 4:
78+
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
79+
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
80+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
81+
else:
82+
weight_up = state_dict[pair_keys[0]].to(torch.float32)
83+
weight_down = state_dict[pair_keys[1]].to(torch.float32)
84+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
85+
86+
# update visited list
87+
for item in pair_keys:
88+
visited.append(item)
89+
90+
return pipeline
91+
92+
93+
if __name__ == "__main__":
94+
parser = argparse.ArgumentParser()
95+
96+
parser.add_argument(
97+
"--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
98+
)
99+
parser.add_argument(
100+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
101+
)
102+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
103+
parser.add_argument(
104+
"--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
105+
)
106+
parser.add_argument(
107+
"--lora_prefix_text_encoder",
108+
default="lora_te",
109+
type=str,
110+
help="The prefix of text encoder weight in safetensors",
111+
)
112+
parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
113+
parser.add_argument(
114+
"--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
115+
)
116+
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
117+
118+
args = parser.parse_args()
119+
120+
base_model_path = args.base_model_path
121+
checkpoint_path = args.checkpoint_path
122+
dump_path = args.dump_path
123+
lora_prefix_unet = args.lora_prefix_unet
124+
lora_prefix_text_encoder = args.lora_prefix_text_encoder
125+
alpha = args.alpha
126+
127+
pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
128+
129+
pipe = pipe.to(args.device)
130+
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)

0 commit comments

Comments
 (0)