@@ -57,6 +57,12 @@ static const std::vector<quant_option> QUANT_OPTIONS = {
57
57
{ " COPY" , LLAMA_FTYPE_ALL_F32, " only copy tensors, no quantizing" , },
58
58
};
59
59
60
+ // Quantization types: pairs a tensor name with the ggml_type requested for it. Changes to this struct must be replicated in llama-quantize.cpp (NOTE(review): the earlier copy of this comment named llama-quant.cpp - confirm which file holds the mirror)
61
+ struct tensor_quantization {
62
+ std::string name; // tensor name; parse_tensor_type stores it lower-cased
63
+ ggml_type quant = GGML_TYPE_COUNT; // GGML_TYPE_COUNT is also what parse_ggml_type returns for an unrecognized type
64
+ };
65
+
60
66
static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = " quantize.imatrix.file" ;
61
67
static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = " quantize.imatrix.dataset" ;
62
68
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = " quantize.imatrix.entries_count" ;
@@ -244,56 +250,10 @@ static ggml_type parse_ggml_type(const char * arg) {
244
250
return type;
245
251
}
246
252
}
247
- fprintf (stderr, " %s: invalid ggml_type '%s'\n " , __func__, arg);
253
+ fprintf (stderr, " \n %s: invalid ggml_type '%s'\n \n" , __func__, arg);
248
254
return GGML_TYPE_COUNT;
249
255
}
250
256
251
- // Allowed tensors for arbitrary quantization with --tensor-type option. parse_tensor_type compares each entry against the part of the user-supplied name that follows its last '.', and accepts the name only on an exact match.
252
- static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
253
- " attn_k" ,
254
- " attn_kv_a_mqa" ,
255
- " attn_kv_b" ,
256
- " attn_o" ,
257
- " attn_output" ,
258
- " attn_q" ,
259
- " attn_q_a" ,
260
- " attn_q_b" ,
261
- " attn_qkv" ,
262
- " attn_v" ,
263
- " channel_mix_key" ,
264
- " channel_mix_receptance" ,
265
- " channel_mix_value" ,
266
- " cls" ,
267
- " cls.output" , // contains a '.', so the lookup in parse_tensor_type special-cases it instead of splitting on the last '.'
268
- " cross_attn_k" ,
269
- " cross_attn_o" ,
270
- " cross_attn_q" ,
271
- " cross_attn_v" ,
272
- " ffn_act" ,
273
- " ffn_down" ,
274
- " ffn_down_exps" ,
275
- " ffn_down_shexp" ,
276
- " ffn_gate" ,
277
- " ffn_gate_exps" ,
278
- " ffn_gate_shexp" ,
279
- " ffn_up" ,
280
- " ffn_up_exps" ,
281
- " ffn_up_shexp" ,
282
- " ssm_in" ,
283
- " ssm_out" ,
284
- " time_mix_gate" ,
285
- " time_mix_key" ,
286
- " time_mix_output" ,
287
- " time_mix_receptance" ,
288
- " time_mix_value" ,
289
- };
290
-
291
- // Pairs a tensor name with the ggml_type requested for it. Changes to this struct must be replicated in llama-quant.cpp
292
- struct tensor_quantization {
293
- std::string name; // tensor name; parse_tensor_type stores it lower-cased
294
- ggml_type quant = GGML_TYPE_COUNT; // GGML_TYPE_COUNT is also what parse_ggml_type returns for an unrecognized type
295
- };
296
-
297
257
static bool parse_tensor_type (const char * data, std::vector<tensor_quantization> & tensor_type) {
298
258
const char * sep = strchr (data, ' =' );
299
259
if (sep == nullptr ) {
@@ -306,7 +266,6 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
306
266
printf (" \n %s: missing tensor name\n\n " , __func__);
307
267
return false ;
308
268
}
309
-
310
269
if (const size_t qt_len = strlen (sep); qt_len == 1 ) {
311
270
printf (" \n %s: missing quantization type\n\n " , __func__);
312
271
return false ;
@@ -315,37 +274,15 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
315
274
std::string tn (data, tn_len);
316
275
std::transform (tn.begin (), tn.end (), tn.begin (), tolower);
317
276
sep++;
318
- const std::string qt (sep);
319
-
320
- bool found = false ;
321
- for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
322
- std::string tensor;
323
- tensor = tn.rfind (' .' ) != std::string::npos ? tn.substr (tn.rfind (' .' ) + 1 ) : tn;
324
- // handle special case of cls.output
325
- std::string cls_output = " cls.output" ;
326
- if (tn.find (cls_output) != std::string::npos) {
327
- tensor = " cls.output" ;
328
- }
329
- // check if an allowed tensor exists and it's at the end of the kv string
330
- if (tensor == allowed) {
331
- found = true ;
332
- break ;
333
- }
334
- }
335
- if (!found) {
336
- printf (" \n %s: invalid tensor name '%s'\n\n " , __func__, tn.c_str ());
337
- return false ;
338
- }
339
-
340
- if (parse_ggml_type (qt.c_str ()) == GGML_TYPE_COUNT) {
341
- printf (" \n %s: invalid quantization type '%s'\n\n " , __func__, qt.c_str ());
342
- return false ;
343
- }
344
-
345
277
tensor_quantization tqz;
346
278
tqz.name = tn;
347
- tqz.quant = parse_ggml_type (qt. c_str () );
279
+ tqz.quant = parse_ggml_type (sep );
348
280
tensor_type.emplace_back (std::move (tqz));
281
+ if (tqz.quant == GGML_TYPE_COUNT) {
282
+ printf (" \n %s: invalid quantization type '%s'\n\n " , __func__, sep);
283
+ return false ;
284
+ }
285
+
349
286
return true ;
350
287
}
351
288
0 commit comments