|
| 1 | +""" |
| 2 | +This script demonstrates how to extract a LoRA checkpoint from a fully finetuned model with the CogVideoX model. |
| 3 | +
|
| 4 | +To make it work for other models: |
| 5 | +
|
| 6 | +* Change the model class. Here we use `CogVideoXTransformer3DModel`. For Flux, it would be `FluxTransformer2DModel`, |
| 7 | +for example. (TODO: more reason to add `AutoModel`). |
| 8 | +* Supply path to the base checkpoint via `--base_ckpt_path`. |
| 9 | +* Supply path to the fully fine-tuned checkpoint via `--finetune_ckpt_path`. |
| 10 | +* Change the `--rank` as needed. |
| 11 | +
|
| 12 | +Example usage: |
| 13 | +
|
| 14 | +```bash |
| 15 | +python extract_lora_from_model.py \ |
| 16 | + --base_ckpt_path=THUDM/CogVideoX-5b \ |
| 17 | + --finetune_ckpt_path=finetrainers/cakeify-v0 \ |
| 18 | + --lora_out_path=cakeify_lora.safetensors |
| 19 | +``` |
| 20 | +
|
| 21 | +Script is adapted from |
| 22 | +https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py |
| 23 | +""" |
| 24 | + |
| 25 | +import argparse |
| 26 | + |
| 27 | +import torch |
| 28 | +from safetensors.torch import save_file |
| 29 | +from tqdm.auto import tqdm |
| 30 | + |
| 31 | +from diffusers import CogVideoXTransformer3DModel |
| 32 | + |
| 33 | + |
RANK = 64
# Quantile of the combined factor values used as the symmetric clamp bound.
CLAMP_QUANTILE = 0.99


# Comes from
# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
def extract_lora(diff, rank):
    """Factor a weight delta into a pair of low-rank LoRA matrices via truncated SVD.

    Args:
        diff: Weight difference tensor. A 4-D tensor is treated as a conv2d kernel
            `(out_dim, in_dim, kh, kw)`; otherwise the first two dimensions are read
            as `(out_dim, in_dim)`. Only 2-D and 4-D inputs are explicitly handled.
        rank: Requested rank; it is capped at `min(rank, in_dim, out_dim)`.

    Returns:
        A tuple `(up, down)` of float32 tensors on the CPU. For 2-D input the shapes
        are `(out_dim, rank)` and `(rank, in_dim)`. For 4-D input they are
        `(out_dim, rank, 1, 1)` and `(rank, in_dim, kh, kw)`.
    """
    # Important to use CUDA otherwise, very slow!
    if torch.cuda.is_available():
        diff = diff.to("cuda")

    is_conv2d = diff.dim() == 4
    kernel_size = tuple(diff.shape[2:4]) if is_conv2d else None
    out_dim, in_dim = diff.shape[:2]
    rank = min(rank, in_dim, out_dim)

    if is_conv2d:
        # Collapse (in_dim, kh, kw) into a single axis. This also covers 1x1 kernels;
        # a bare `squeeze()` would additionally drop `out_dim`/`in_dim` when either
        # equals 1 and leave a 1-D tensor that SVD cannot process.
        diff = diff.flatten(start_dim=1)

    # Only the leading `rank` singular triplets are kept, so the reduced SVD is
    # sufficient and avoids materializing the full square factors.
    U, S, Vh = torch.linalg.svd(diff.float(), full_matrices=False)
    # Fold the singular values into the columns of U (same as `U @ diag(S)`).
    U = U[:, :rank] * S[:rank]
    Vh = Vh[:rank, :]

    # Clamp both factors symmetrically at the chosen quantile of their joint values.
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if is_conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U.cpu(), Vh.cpu())
| 73 | + |
| 74 | + |
def parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Optional list of argument strings. When `None` (the default),
            argparse reads them from `sys.argv[1:]`.

    Returns:
        The parsed `argparse.Namespace`.

    Raises:
        ValueError: If `--lora_out_path` does not end with `.safetensors`, or if
            `--rank` is not a positive integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_ckpt_path",
        type=str,
        required=True,
        help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
    )
    parser.add_argument(
        "--base_subfolder",
        default="transformer",
        type=str,
        help="subfolder to load the base checkpoint from if any.",
    )
    parser.add_argument(
        "--finetune_ckpt_path",
        type=str,
        required=True,
        help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
    )
    parser.add_argument(
        "--finetune_subfolder",
        default=None,
        type=str,
        help="subfolder to load the fully finetuned checkpoint from if any.",
    )
    parser.add_argument("--rank", default=64, type=int, help="Rank of the extracted LoRA.")
    parser.add_argument(
        "--lora_out_path",
        type=str,
        required=True,
        help="Output file for the extracted LoRA. Must end with `.safetensors`.",
    )
    args = parser.parse_args(argv)

    if not args.lora_out_path.endswith(".safetensors"):
        raise ValueError("`lora_out_path` must end with `.safetensors`.")
    # A non-positive rank would produce empty or mis-sliced factors downstream.
    if args.rank < 1:
        raise ValueError("`rank` must be a positive integer.")

    return args
| 111 | + |
| 112 | + |
@torch.no_grad()
def main(args):
    """Extract a LoRA from the difference between a fine-tuned and a base checkpoint.

    Loads both checkpoints in bfloat16 and iterates over the base model's state
    dict. Every tensor with at least two dimensions is decomposed with
    `extract_lora`, and the resulting `lora_A`/`lora_B` weights are saved to
    `args.lora_out_path` as a safetensors file.

    Args:
        args: Namespace providing `finetune_ckpt_path`, `finetune_subfolder`,
            `base_ckpt_path`, `base_subfolder`, `rank` and `lora_out_path`.
    """
    model_finetuned = CogVideoXTransformer3DModel.from_pretrained(
        args.finetune_ckpt_path, subfolder=args.finetune_subfolder, torch_dtype=torch.bfloat16
    )
    state_dict_ft = model_finetuned.state_dict()

    # Change the `subfolder` as needed.
    base_model = CogVideoXTransformer3DModel.from_pretrained(
        args.base_ckpt_path, subfolder=args.base_subfolder, torch_dtype=torch.bfloat16
    )
    state_dict = base_model.state_dict()
    output_dict = {}

    for k in tqdm(state_dict, desc="Extracting LoRA..."):
        original_param = state_dict[k]
        finetuned_param = state_dict_ft[k]
        # Tensors with fewer than two dimensions are skipped.
        if len(original_param.shape) < 2:
            continue

        diff = finetuned_param.float() - original_param.float()
        # Honor the user-supplied `--rank` rather than the module-level constant.
        up, down = extract_lora(diff, args.rank)

        name = k[: -len(".weight")] if k.endswith(".weight") else k
        output_dict[f"{name}.lora_B.weight"] = up.contiguous().to(finetuned_param.dtype)
        output_dict[f"{name}.lora_A.weight"] = down.contiguous().to(finetuned_param.dtype)

    # Prefix each key according to the class name of the loaded base model.
    prefix = "transformer" if "transformer" in base_model.__class__.__name__.lower() else "unet"
    output_dict = {f"{prefix}.{k}": v for k, v in output_dict.items()}
    save_file(output_dict, args.lora_out_path)
    print(f"LoRA saved and it contains {len(output_dict)} keys.")
| 147 | + |
| 148 | + |
# Script entry point: parse the CLI options, then run the extraction.
if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)
0 commit comments