|
18 | 18 | # compatible with python < 3.9
|
19 | 19 | NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
|
20 | 20 |
|
21 |
| -def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray: |
| 21 | +# reverse HF permute back to original pth layout |
| 22 | +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py |
| 23 | +def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray: |
22 | 24 | if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head
|
23 | 25 | return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
24 | 26 | .swapaxes(1, 2)
|
@@ -219,9 +221,9 @@ def count_model_parts(dir_model: str) -> int:
|
219 | 221 |
|
220 | 222 | data = data.squeeze().numpy()
|
221 | 223 |
|
222 |
| - # permute these |
| 224 | + # reverse permute these |
223 | 225 | if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
|
224 |
| - data = permute(data, head_count, head_count_kv) |
| 226 | + data = reverse_hf_permute(data, head_count, head_count_kv) |
225 | 227 |
|
226 | 228 | # map tensor names
|
227 | 229 | if name.endswith(".weight") and name[:-7] in tensor_map:
|
@@ -288,9 +290,9 @@ def count_model_parts(dir_model: str) -> int:
|
288 | 290 |
|
289 | 291 | data = data.squeeze().numpy()
|
290 | 292 |
|
291 |
| - # permute these |
| 293 | + # reverse permute these |
292 | 294 | if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
|
293 |
| - data = permute(data, head_count, head_count_kv) |
| 295 | + data = reverse_hf_permute(data, head_count, head_count_kv) |
294 | 296 |
|
295 | 297 | # map tensor names
|
296 | 298 | if name.endswith(".weight") and name[:-7] in tensor_map:
|
|
0 commit comments