Skip to content

Commit c85d317

Browse files
authored
refactor : reduce code duplication and better API (#2415)
1 parent d8491fc commit c85d317

File tree

1 file changed

+25
-32
lines changed

1 file changed

+25
-32
lines changed

gguf.py

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -71,63 +71,56 @@ def open(cls, path: str) -> "GGUFWriter":
7171
f = open(path, "wb")
7272
return cls(f)
7373

74-
def write_key(self, key: str, value_type: GGUFValueType):
75-
encoded_key = key.encode("utf8")
76-
self.buffered_writer.write(struct.pack("<I", len(encoded_key)))
77-
self.buffered_writer.write(encoded_key)
78-
self.buffered_writer.write(struct.pack("<I", value_type))
74+
def write_key(self, key: str):
75+
self.write_value(key, GGUFValueType.STRING)
7976

8077
def write_uint8(self, key: str, value: int):
81-
self.write_key(key, GGUFValueType.UINT8)
82-
self.buffered_writer.write(struct.pack("<B", value))
78+
self.write_key(key)
79+
self.write_value(value, GGUFValueType.UINT8)
8380

8481
def write_int8(self, key: str, value: int):
85-
self.write_key(key, GGUFValueType.INT8)
86-
self.buffered_writer.write(struct.pack("<b", value))
82+
self.write_key(key)
83+
self.write_value(value, GGUFValueType.INT8)
8784

8885
def write_uint16(self, key: str, value: int):
89-
self.write_key(key, GGUFValueType.UINT16)
90-
self.buffered_writer.write(struct.pack("<H", value))
86+
self.write_key(key)
87+
self.write_value(value, GGUFValueType.UINT16)
9188

9289
def write_int16(self, key: str, value: int):
93-
self.write_key(key, GGUFValueType.INT16)
94-
self.buffered_writer.write(struct.pack("<h", value))
90+
self.write_key(key)
91+
self.write_value(value, GGUFValueType.INT16)
9592

9693
def write_uint32(self, key: str, value: int):
97-
self.write_key(key, GGUFValueType.UINT32)
98-
self.buffered_writer.write(struct.pack("<I", value))
94+
self.write_key(key)
95+
self.write(value, GGUFValueType.UINT32)
9996

10097
def write_int32(self, key: str, value: int):
101-
self.write_key(key, GGUFValueType.INT32)
102-
self.buffered_writer.write(struct.pack("<i", value))
98+
self.write_key(key)
99+
self.write_value(value, GGUFValueType.INT32)
103100

104101
def write_float32(self, key: str, value: float):
105-
self.write_key(key, GGUFValueType.FLOAT32)
106-
self.buffered_writer.write(struct.pack("<f", value))
102+
self.write_key(key)
103+
self.write_value(value, GGUFValueType.FLOAT32)
107104

108105
def write_bool(self, key: str, value: bool):
109-
self.write_key(key, GGUFValueType.BOOL)
110-
self.buffered_writer.write(struct.pack("<?", value))
106+
self.write_key(key)
107+
self.write_value(value, GGUFValueType.BOOL)
111108

112109
def write_string(self, key: str, value: str):
113-
self.write_key(key, GGUFValueType.STRING)
114-
encoded_string = value.encode('utf-8')
115-
self.buffered_writer.write(struct.pack("<I", len(encoded_string)))
116-
self.buffered_writer.write(encoded_string)
110+
self.write_key(key)
111+
self.write_value(value, GGUFValueType.STRING)
117112

118113
def write_array(self, key: str, value: list):
119114
if not isinstance(value, list):
120115
raise ValueError("Value must be a list for array type")
121116

122-
self.write_key(key, GGUFValueType.ARRAY)
123-
124-
self.buffered_writer.write(struct.pack("<I", len(value)))
117+
self.write_key(key)
118+
self.write_value(value, GGUFValueType.ARRAY)
125119

126-
for item in value:
127-
self.write_value(item)
120+
def write_value(self: str, value: Any, value_type: GGUFValueType = None):
121+
if value_type is None:
122+
value_type = GGUFValueType.get_type(value)
128123

129-
def write_value(self: str, value: Any):
130-
value_type = GGUFValueType.get_type(value)
131124
self.buffered_writer.write(struct.pack("<I", value_type))
132125

133126
if value_type == GGUFValueType.UINT8:

0 commit comments

Comments
 (0)