Skip to content

Commit ea5615a

Browse files
authored
convert-llama-h5-to-gguf.py : clarify the reverse permute
1 parent 4a1741a commit ea5615a

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

convert-llama-h5-to-gguf.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,9 @@
1818
# compatible with python < 3.9
1919
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
2020

21-
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
21+
# reverse HF permute back to original pth layout
22+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
23+
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
2224
if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head
2325
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
2426
.swapaxes(1, 2)
@@ -219,9 +221,9 @@ def count_model_parts(dir_model: str) -> int:
219221

220222
data = data.squeeze().numpy()
221223

222-
# permute these
224+
# reverse permute these
223225
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
224-
data = permute(data, head_count, head_count_kv)
226+
data = reverse_hf_permute(data, head_count, head_count_kv)
225227

226228
# map tensor names
227229
if name.endswith(".weight") and name[:-7] in tensor_map:
@@ -288,9 +290,9 @@ def count_model_parts(dir_model: str) -> int:
288290

289291
data = data.squeeze().numpy()
290292

291-
# permute these
293+
# reverse permute these
292294
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
293-
data = permute(data, head_count, head_count_kv)
295+
data = reverse_hf_permute(data, head_count, head_count_kv)
294296

295297
# map tensor names
296298
if name.endswith(".weight") and name[:-7] in tensor_map:

0 commit comments

Comments
 (0)