Skip to content

Commit 7bb36cc

Browse files
ngxsonslaren
andauthored
gguf : enforce that tensor names are unique (#6905)
* not allow adding duplicated tensor name * no duplicated tensor while reading gguf * typo * throw exception inside llama_model_loader Co-authored-by: slaren <slarengh@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com>
1 parent ce023f6 commit 7bb36cc

File tree

4 files changed

+32
-1
lines changed

4 files changed

+32
-1
lines changed

ggml.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20819,6 +20819,14 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
2081920819
// TODO: return an error instead of crashing with GGML_ASSERT
2082020820
gguf_tensor_info_sanitize(info);
2082120821

20822+
// make sure there is no duplicated tensor names
20823+
for (uint64_t j = 0; j < i; ++j) {
20824+
if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
20825+
fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
20826+
ok = false;
20827+
}
20828+
}
20829+
2082220830
if (!ok) {
2082320831
fprintf(stderr, "%s: failed to read tensor info\n", __func__);
2082420832
fclose(file);
@@ -21355,6 +21363,10 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
2135521363
void gguf_add_tensor(
2135621364
struct gguf_context * ctx,
2135721365
const struct ggml_tensor * tensor) {
21366+
if (gguf_find_tensor(ctx, tensor->name) != -1) {
21367+
GGML_ASSERT(false && "duplicated tensor name");
21368+
}
21369+
2135821370
const int idx = ctx->header.n_tensors;
2135921371
ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
2136021372

gguf-py/gguf/gguf_reader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,14 @@ def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[Reader
234234

235235
def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
236236
tensors = []
237+
tensor_names = set() # keep track of name to prevent duplicated tensors
237238
for field in fields:
238239
_name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
240+
# check if there's any tensor having same name already in the list
241+
tensor_name = str(bytes(name_data), encoding = 'utf-8')
242+
if tensor_name in tensor_names:
243+
raise ValueError(f'Found duplicated tensor with name {tensor_name}')
244+
tensor_names.add(tensor_name)
239245
ggml_type = GGMLQuantizationType(raw_dtype[0])
240246
n_elems = np.prod(dims)
241247
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
@@ -267,7 +273,7 @@ def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
267273
item_count = n_bytes
268274
item_type = np.uint8
269275
tensors.append(ReaderTensor(
270-
name = str(bytes(name_data), encoding = 'utf-8'),
276+
name = tensor_name,
271277
tensor_type = ggml_type,
272278
shape = dims,
273279
n_elements = n_elems,

gguf-py/gguf/gguf_writer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
self.kv_data_count = 0
6464
self.ti_data = bytearray()
6565
self.ti_data_count = 0
66+
self.ti_names = set()
6667
self.use_temp_file = use_temp_file
6768
self.temp_file = None
6869
self.tensors = []
@@ -197,6 +198,10 @@ def add_tensor_info(
197198
if self.state is not WriterState.EMPTY:
198199
raise ValueError(f'Expected output file to be empty, got {self.state}')
199200

201+
if name in self.ti_names:
202+
raise ValueError(f'Duplicated tensor name {name}')
203+
self.ti_names.add(name)
204+
200205
encoded_name = name.encode("utf8")
201206
self.ti_data += self._pack("Q", len(encoded_name))
202207
self.ti_data += encoded_name

llama.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3120,9 +3120,17 @@ struct llama_model_loader {
31203120

31213121
fver = (enum llama_fver) gguf_get_version(meta);
31223122

3123+
std::set<std::string> tensor_names;
31233124
for (auto & w : weights) {
31243125
n_elements += ggml_nelements(w.tensor);
31253126
n_bytes += ggml_nbytes(w.tensor);
3127+
// make sure there is no duplicated tensor names
3128+
const std::string name(w.tensor->name);
3129+
auto found = tensor_names.find(name);
3130+
if (found != tensor_names.end()) {
3131+
throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", w.tensor->name));
3132+
}
3133+
tensor_names.insert(name);
31263134
}
31273135

31283136
LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",

0 commit comments

Comments
 (0)