Skip to content

Commit 6c68218

Browse files
simonMoisselin and simonMoisselin authored
models : add ggml_to_pt script (ggml-org#1042)
* adding ggml_to_pt
* typo sys too many args
* fixing swap errors dimensions

---------

Co-authored-by: simonMoisselin <simon.moisselin@gmail.com>
1 parent f11f33f commit 6c68218

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed

models/ggml_to_pt.py

Lines changed: 109 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,109 @@
"""Convert a Whisper ggml model file into a PyTorch state dict.

Usage: ggml_to_pt.py model.bin dir-output

The input file is parsed in the order this script has always read it:
12 int32 header values, the mel filter bank, the tokenizer vocabulary and
then a sequence of tensors up to end of file.  The tensors are loaded into an
``openai-whisper`` ``Whisper`` model and its ``state_dict()`` is written to
``dir-output/torch-model.pt``.
"""
import struct
import sys
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch

# "ggml" in hex; first int32 of every ggml model file.
GGML_MAGIC = 0x67676D6C

# Names of the 12 int32 header values, in file order.
_HEADER_FIELDS = (
    "magic_number",
    "n_vocab",
    "n_audio_ctx",
    "n_audio_state",
    "n_audio_head",
    "n_audio_layer",
    "n_text_ctx",
    "n_text_state",
    "n_text_head",
    "n_text_layer",
    "n_mels",
    "use_f16",
)

# Header values that are forwarded to whisper.ModelDimensions.
_DIMENSION_FIELDS = (
    "n_mels",
    "n_audio_ctx",
    "n_audio_state",
    "n_audio_head",
    "n_audio_layer",
    "n_text_ctx",
    "n_text_state",
    "n_text_head",
    "n_text_layer",
    "n_vocab",
)

# Per-tensor ftype codes this script can decode.  Anything else (for example
# a quantized tensor) is rejected instead of being misread as float32.
_FTYPE_TO_DTYPE = {0: np.float32, 1: np.float16}

# These biases are stored with a trailing singleton axis that the PyTorch
# model does not have.
_CONV_BIAS_NAMES = ("encoder.conv1.bias", "encoder.conv2.bias")


def _read_exact(f, size):
    """Read exactly ``size`` bytes from ``f``.

    Raises:
        ValueError: if ``size`` is negative (a corrupt length field; passing
            it to ``read`` would consume the rest of the file).
        EOFError: if the stream ends before ``size`` bytes are available.
    """
    if size < 0:
        raise ValueError(f"invalid negative length in ggml file: {size}")
    data = f.read(size)
    if len(data) != size:
        raise EOFError(
            f"unexpected end of file: wanted {size} bytes, got {len(data)}"
        )
    return data


def _read_int(f):
    """Read one native int32 from ``f``."""
    return struct.unpack("i", _read_exact(f, 4))[0]


def _read_array(f, dtype, shape):
    """Read a dense array of ``dtype`` with the given ``shape`` from ``f``."""
    dtype = np.dtype(dtype)
    count = int(np.prod(shape, dtype=np.int64))
    raw = _read_exact(f, count * dtype.itemsize)
    # Copy so the array owns writable memory; torch.from_numpy warns on
    # read-only buffers such as the ones np.frombuffer creates from bytes.
    return np.frombuffer(raw, dtype=dtype).reshape(shape).copy()


def _read_header(f):
    """Read, validate and print the hyperparameter header.

    Returns:
        dict mapping each name in ``_HEADER_FIELDS`` to its int value.

    Raises:
        ValueError: if the magic number is not ``GGML_MAGIC``.
    """
    values = struct.unpack("12i", _read_exact(f, 48))
    hparams = dict(zip(_HEADER_FIELDS, values))
    if hparams["magic_number"] != GGML_MAGIC:
        raise ValueError(
            f"bad magic number {hparams['magic_number']:#x}, "
            f"expected {GGML_MAGIC:#x}"
        )

    print(f"Magic number: {hparams['magic_number']}")
    print(f"Vocab size: {hparams['n_vocab']}")
    print(f"Audio context size: {hparams['n_audio_ctx']}")
    print(f"Audio state size: {hparams['n_audio_state']}")
    print(f"Audio head size: {hparams['n_audio_head']}")
    print(f"Audio layer size: {hparams['n_audio_layer']}")
    print(f"Text context size: {hparams['n_text_ctx']}")
    print(f"Text state size: {hparams['n_text_state']}")
    print(f"Text head size: {hparams['n_text_head']}")
    print(f"Text layer size: {hparams['n_text_layer']}")
    print(f"Mel size: {hparams['n_mels']}")
    return hparams


def _read_mel_filters(f):
    """Read the mel filter bank: two int32 dimensions, then float32 values."""
    rows = _read_int(f)
    print(f"Filters shape 0: {rows}")
    cols = _read_int(f)
    print(f"Filters shape 1: {cols}")
    if rows < 0 or cols < 0:
        raise ValueError(f"invalid mel filter shape: ({rows}, {cols})")
    return _read_array(f, np.float32, (rows, cols))


def _read_tokens(f):
    """Read the vocabulary: an int32 count, then length-prefixed byte strings."""
    num_tokens = _read_int(f)
    return [_read_exact(f, _read_int(f)) for _ in range(num_tokens)]


def _read_tensors(f):
    """Read every remaining tensor in ``f`` into an ordered state dict.

    Each record is ``(n_dims, name_length, ftype)`` as int32, the dimensions
    as int32 in reverse order, the UTF-8 name, and finally the raw data.

    Raises:
        ValueError: on an ftype that is not float32 (0) or float16 (1).
        EOFError: if a record is truncated.
    """
    state_dict = OrderedDict()
    while True:
        record_header = f.read(12)
        if not record_header:
            break  # clean end of file
        if len(record_header) != 12:
            raise EOFError("unexpected end of file inside a tensor header")
        n_dims, name_length, ftype = struct.unpack("iii", record_header)

        raw_dims = _read_exact(f, 4 * n_dims)
        # Dimensions are stored innermost-first; flip them to NumPy order.
        shape = struct.unpack(f"{n_dims}i", raw_dims)[::-1]
        name = _read_exact(f, name_length).decode("utf-8")

        if ftype not in _FTYPE_TO_DTYPE:
            raise ValueError(
                f"tensor {name!r} has unsupported ftype {ftype}; only "
                "float32 (0) and float16 (1) can be converted"
            )
        data = _read_array(f, _FTYPE_TO_DTYPE[ftype], shape)

        if name in _CONV_BIAS_NAMES and data.ndim == 2:
            data = data[:, 0]

        state_dict[name] = torch.from_numpy(data)
    return state_dict


def load_ggml(f):
    """Parse a ggml Whisper model from the binary stream ``f``.

    Args:
        f: a binary file-like object positioned at the start of the model.

    Returns:
        A tuple ``(hparams, mel_filters, tokens, state_dict)`` where
        ``hparams`` is a dict of the header ints, ``mel_filters`` is a
        float32 NumPy array, ``tokens`` is a list of ``bytes`` and
        ``state_dict`` is an ``OrderedDict`` of name -> ``torch.Tensor``.

    Raises:
        ValueError: on a bad magic number, negative length or unsupported
            tensor type.
        EOFError: if the stream is truncated.
    """
    hparams = _read_header(f)
    mel_filters = _read_mel_filters(f)
    tokens = _read_tokens(f)
    state_dict = _read_tensors(f)
    return hparams, mel_filters, tokens, state_dict


def main(argv=None):
    """Convert ``argv[1]`` (ggml file) into ``argv[2]/torch-model.pt``.

    Args:
        argv: argument list including the program name; defaults to
            ``sys.argv``.

    Returns:
        Process exit status: 0 on success, 1 on a usage error.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        print(f"Usage: {Path(argv[0]).name if argv else 'ggml_to_pt.py'} "
              "model.bin dir-output\n")
        return 1

    fname_inp = Path(argv[1])
    dir_out = Path(argv[2])
    fname_out = dir_out / "torch-model.pt"

    # Open the ggml file
    with open(fname_inp, "rb") as f:
        hparams, _mel_filters, _tokens, model_state_dict = load_ggml(f)

    # Imported here so that parsing a file does not require openai-whisper.
    from whisper import ModelDimensions, Whisper

    dims = ModelDimensions(**{key: hparams[key] for key in _DIMENSION_FIELDS})
    model = Whisper(dims)
    model.load_state_dict(model_state_dict)

    # Save the model in PyTorch format
    dir_out.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), fname_out)
    return 0


if __name__ == "__main__":
    sys.exit(main())

0 commit comments

Comments
 (0)