Skip to content

Commit 5897137

Browse files
authored
[chore] add a script to extract loras from full fine-tuned models (#10631)
* feat: add a lora extraction script. * updates
1 parent a451c0e commit 5897137

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed

scripts/extract_lora_from_model.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""
2+
This script demonstrates how to extract a LoRA checkpoint from a fully finetuned model with the CogVideoX model.
3+
4+
To make it work for other models:
5+
6+
* Change the model class. Here we use `CogVideoXTransformer3DModel`. For Flux, it would be `FluxTransformer2DModel`,
7+
for example. (TODO: more reason to add `AutoModel`).
8+
* Supply path to the base checkpoint via `--base_ckpt_path`.
9+
* Supply path to the fully fine-tuned checkpoint via `--finetune_ckpt_path`.
10+
* Change the `--rank` as needed.
11+
12+
Example usage:
13+
14+
```bash
15+
python extract_lora_from_model.py \
16+
--base_ckpt_path=THUDM/CogVideoX-5b \
17+
--finetune_ckpt_path=finetrainers/cakeify-v0 \
18+
--lora_out_path=cakeify_lora.safetensors
19+
```
20+
21+
Script is adapted from
22+
https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py
23+
"""
24+
25+
import argparse
26+
27+
import torch
28+
from safetensors.torch import save_file
29+
from tqdm.auto import tqdm
30+
31+
from diffusers import CogVideoXTransformer3DModel
32+
33+
34+
# Default LoRA rank (same value as the `--rank` CLI default).
RANK = 64
# Quantile of the combined factor values used as the symmetric clamp bound.
CLAMP_QUANTILE = 0.99


# Comes from
# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
def extract_lora(diff, rank):
    """Factor a weight delta into a low-rank ``(up, down)`` pair via truncated SVD.

    Args:
        diff: Weight difference tensor. Either a 2-D matrix of shape
            ``(out_dim, in_dim)`` or a 4-D conv kernel of shape
            ``(out_dim, in_dim, kh, kw)``.
        rank: Requested rank. It is capped at ``min(in_dim, out_dim)``.

    Returns:
        A tuple ``(up, down)`` of float32 CPU tensors. For 2-D input the shapes
        are ``(out_dim, rank)`` and ``(rank, in_dim)``; for 4-D input they are
        ``(out_dim, rank, 1, 1)`` and ``(rank, in_dim, kh, kw)``.

    Raises:
        ValueError: If ``diff`` is neither 2-D nor 4-D, or if ``rank`` is < 1.
    """
    if diff.dim() not in (2, 4):
        raise ValueError(f"`diff` must be a 2-D or 4-D tensor, got {diff.dim()} dimensions.")
    if rank < 1:
        raise ValueError(f"`rank` must be a positive integer, got {rank}.")

    # Important to use CUDA otherwise, very slow!
    if torch.cuda.is_available():
        diff = diff.to("cuda")

    is_conv2d = diff.dim() == 4
    kernel_size = tuple(diff.shape[2:4]) if is_conv2d else None
    out_dim, in_dim = diff.shape[0], diff.shape[1]
    rank = min(rank, in_dim, out_dim)

    if is_conv2d:
        # Collapse (in_dim, kh, kw) into one axis. Unlike `squeeze()`, this keeps
        # `out_dim`/`in_dim` intact even when one of them equals 1.
        diff = diff.flatten(start_dim=1)

    # Only the leading `rank` singular triplets are kept, so the reduced SVD is
    # sufficient and avoids building the full square factors.
    U, S, Vh = torch.linalg.svd(diff.float(), full_matrices=False)
    U = U[:, :rank] @ torch.diag(S[:rank])  # fold the singular values into the up factor
    Vh = Vh[:rank, :]

    # Clip outliers symmetrically at the chosen quantile of all factor values.
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if is_conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U.cpu(), Vh.cpu())
73+
74+
75+
def parse_args(argv=None):
    """Parse and validate the command-line options of the script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with the attributes
        ``base_ckpt_path``, ``base_subfolder``, ``finetune_ckpt_path``,
        ``finetune_subfolder``, ``rank`` and ``lora_out_path``.

    Raises:
        ValueError: If ``lora_out_path`` does not end with ``.safetensors`` or
            if ``rank`` is not a positive integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_ckpt_path",
        type=str,
        required=True,
        help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
    )
    parser.add_argument(
        "--base_subfolder",
        default="transformer",
        type=str,
        help="subfolder to load the base checkpoint from if any.",
    )
    parser.add_argument(
        "--finetune_ckpt_path",
        type=str,
        required=True,
        help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
    )
    parser.add_argument(
        "--finetune_subfolder",
        default=None,
        type=str,
        help="subfolder to load the fulle finetuned checkpoint from if any.",
    )
    parser.add_argument("--rank", default=64, type=int, help="Rank of the extracted LoRA.")
    parser.add_argument(
        "--lora_out_path",
        type=str,
        required=True,
        help="Output file for the extracted LoRA. Must end with `.safetensors`.",
    )
    args = parser.parse_args(argv)

    if not args.lora_out_path.endswith(".safetensors"):
        raise ValueError("`lora_out_path` must end with `.safetensors`.")
    # Fail early: a non-positive rank would leave nothing to extract.
    if args.rank < 1:
        raise ValueError("`rank` must be a positive integer.")

    return args
111+
112+
113+
@torch.no_grad()
def main(args):
    """Extract a LoRA from the delta between a fine-tuned and a base checkpoint.

    Loads both checkpoints in bfloat16, runs `extract_lora` on the weight
    difference of every parameter with at least two dimensions, and saves the
    resulting `lora_A` / `lora_B` tensors to `args.lora_out_path`.

    Args:
        args: Parsed CLI namespace. Reads `finetune_ckpt_path`,
            `finetune_subfolder`, `base_ckpt_path`, `base_subfolder`, `rank`
            and `lora_out_path`.
    """
    model_finetuned = CogVideoXTransformer3DModel.from_pretrained(
        args.finetune_ckpt_path, subfolder=args.finetune_subfolder, torch_dtype=torch.bfloat16
    )
    state_dict_ft = model_finetuned.state_dict()

    # Change the `subfolder` as needed.
    base_model = CogVideoXTransformer3DModel.from_pretrained(
        args.base_ckpt_path, subfolder=args.base_subfolder, torch_dtype=torch.bfloat16
    )
    state_dict = base_model.state_dict()
    output_dict = {}

    for k in tqdm(state_dict, desc="Extracting LoRA..."):
        original_param = state_dict[k]
        finetuned_param = state_dict_ft[k]
        # Parameters with fewer than two dimensions (e.g. biases) are skipped.
        if len(original_param.shape) < 2:
            continue

        diff = finetuned_param.float() - original_param.float()
        # Honor the user-supplied `--rank` instead of the module-level default.
        lora_up, lora_down = extract_lora(diff, args.rank)

        name = k[: -len(".weight")] if k.endswith(".weight") else k
        down_key = f"{name}.lora_A.weight"
        up_key = f"{name}.lora_B.weight"

        # Store the factors in the same dtype as the fine-tuned weights.
        output_dict[up_key] = lora_up.contiguous().to(finetuned_param.dtype)
        output_dict[down_key] = lora_down.contiguous().to(finetuned_param.dtype)

    # Prefix the keys with the component name derived from the model class name.
    prefix = "transformer" if "transformer" in base_model.__class__.__name__.lower() else "unet"
    output_dict = {f"{prefix}.{k}": v for k, v in output_dict.items()}
    save_file(output_dict, args.lora_out_path)
    print(f"LoRA saved and it contains {len(output_dict)} keys.")
147+
148+
149+
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the extraction.
    main(parse_args())

0 commit comments

Comments
 (0)