
Commit c6c4fc0

lora : add support for non-llama models (#3333)
* lora : add support for non-llama models

  ggml-ci

* avoid leaking ggml_context on failure cleanup

  ggml-ci

* lora : allow 1d tensors
* lora : include embd and output layers in size calculation
* fix style
1 parent 8a5be3b commit c6c4fc0
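In short, the converter no longer hard-codes the LLaMA tensor layout: it takes an optional architecture argument and resolves tensor names through gguf.TensorNameMap, so LoRA adapters for any architecture listed in gguf.MODEL_ARCH_NAMES can be converted. Per the updated usage message, the invocation is

    python convert-lora-to-ggml.py <path> [arch]

where <path> is a directory containing the PEFT files adapter_config.json and adapter_model.bin, and [arch] defaults to llama.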

File tree

3 files changed (+113, -105 lines)


convert-lora-to-ggml.py

Lines changed: 44 additions & 40 deletions
@@ -3,49 +3,20 @@
 
 import json
 import os
-import re
 import struct
 import sys
 from typing import Any, BinaryIO, Sequence
 
 import numpy as np
 import torch
 
-NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
-
+from pathlib import Path
+if 'NO_LOCAL_GGUF' not in os.environ:
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
+import gguf
 
-HF_SUBLAYER_TO_GGML = {
-    "self_attn.q_proj": "attn_q",
-    "self_attn.k_proj": "attn_k",
-    "self_attn.v_proj": "attn_v",
-    "self_attn.o_proj": "attn_output",
-    "mlp.gate_proj": "ffn_gate",
-    "mlp.down_proj": "ffn_down",
-    "mlp.up_proj": "ffn_up",
-    "input_layernorm": "attn_norm",
-    "post_attention_layernorm": "ffn_norm",
-}
-
-
-def translate_tensor_name(t: str) -> str:
-    match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
-    if match:
-        nn = match.group(1)
-        sub_layer = match.group(2)
-        lora_type = match.group(3)
-
-        sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
-        if sub_layer_renamed is None:
-            print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
-            sys.exit(1)
 
-        output_string = (
-            f"blk.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.weight.lora{lora_type}"
-        )
-        return output_string
-    else:
-        print(f"Error: unrecognized tensor {t}")
-        sys.exit(1)
+NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
 
 
 def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
@@ -61,9 +32,7 @@ def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
     fout.write(struct.pack("i", int(params["lora_alpha"])))
 
 
-def write_tensor_header(
-    self, name: str, shape: Sequence[int], data_type: np.dtype[Any]
-) -> None:
+def write_tensor_header(fout: BinaryIO, name: str, shape: Sequence[int], data_type: np.dtype[Any]) -> None:
     sname = name.encode("utf-8")
     fout.write(
         struct.pack(
@@ -78,18 +47,27 @@ def write_tensor_header(
     fout.seek((fout.tell() + 31) & -32)
 
 
-if len(sys.argv) != 2:
-    print(f"Usage: python {sys.argv[0]} <path>")
+if len(sys.argv) < 2:
+    print(f"Usage: python {sys.argv[0]} <path> [arch]")
     print(
         "Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
     )
+    print(f"Arch must be one of {list(gguf.MODEL_ARCH_NAMES.values())} (default: llama)")
     sys.exit(1)
 
 input_json = os.path.join(sys.argv[1], "adapter_config.json")
 input_model = os.path.join(sys.argv[1], "adapter_model.bin")
 output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
 
 model = torch.load(input_model, map_location="cpu")
+arch_name = sys.argv[2] if len(sys.argv) == 3 else "llama"
+
+if arch_name not in gguf.MODEL_ARCH_NAMES.values():
+    print(f"Error: unsupported architecture {arch_name}")
+    sys.exit(1)
+
+arch = list(gguf.MODEL_ARCH_NAMES.keys())[list(gguf.MODEL_ARCH_NAMES.values()).index(arch_name)]
+name_map = gguf.TensorNameMap(arch, 200) # 200 layers ought to be enough for anyone
 
 with open(input_json, "r") as f:
     params = json.load(f)
@@ -117,6 +95,7 @@ def write_tensor_header(
 
 write_file_header(fout, params)
 for k, v in model.items():
+    orig_k = k
     if k.endswith(".default.weight"):
         k = k.replace(".default.weight", ".weight")
     if k in ["llama_proj.weight", "llama_proj.bias"]:
@@ -129,7 +108,32 @@ def write_tensor_header(
         v = v.float()
 
     t = v.detach().numpy()
-    tname = translate_tensor_name(k)
+
+    prefix = "base_model.model."
+    if k.startswith(prefix):
+        k = k[len(prefix) :]
+
+    lora_suffixes = (".lora_A.weight", ".lora_B.weight")
+    if k.endswith(lora_suffixes):
+        suffix = k[-len(lora_suffixes[0]):]
+        k = k[: -len(lora_suffixes[0])]
+    else:
+        print(f"Error: unrecognized tensor name {orig_k}")
+        sys.exit(1)
+
+    tname = name_map.get_name(k)
+    if tname is None:
+        print(f"Error: could not map tensor name {orig_k}")
+        print(" Note: the arch parameter must be specified if the model is not llama")
+        sys.exit(1)
+
+    if suffix == ".lora_A.weight":
+        tname += ".weight.loraA"
+    elif suffix == ".lora_B.weight":
+        tname += ".weight.loraB"
+    else:
+        assert False
+
     print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
     write_tensor_header(fout, tname, t.shape, t.dtype)
     t.tofile(fout)
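For readers tracing the new mapping path, below is a minimal sketch (not part of the commit) of how a PEFT tensor name becomes a GGML LoRA tensor name. It assumes the gguf-py package from this repository is importable and that gguf.MODEL_ARCH.LLAMA is the enum entry whose name is "llama"; the expected outputs in the comments follow from the gguf tensor-mapping tables and will differ for other architectures.

# Minimal sketch, not part of the commit: mapping one PEFT LoRA tensor name.
import gguf

arch = gguf.MODEL_ARCH.LLAMA              # in the script this comes from the optional [arch] argument
name_map = gguf.TensorNameMap(arch, 200)  # same generous layer count as the script uses

k = "base_model.model.model.layers.7.self_attn.q_proj.lora_A.weight"
k = k[len("base_model.model."):]          # strip the PEFT prefix
k = k[:-len(".lora_A.weight")]            # strip the LoRA suffix -> "model.layers.7.self_attn.q_proj"

tname = name_map.get_name(k)              # expected: "blk.7.attn_q"
print(tname + ".weight.loraA")            # expected: "blk.7.attn_q.weight.loraA"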
