Skip to content

Commit e5c834f

Browse files
authored (author name not captured in this extraction)
quantize : improve tensor-type pattern matching (#13033)
1 parent 71bdbdb commit e5c834f

File tree

2 files changed

+26
-87
lines changed

2 files changed

+26
-87
lines changed

src/llama-quant.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
#include <thread>
1515
#include <unordered_map>
1616

17+
// Quantization types. Changes to this struct must be replicated in quantize.cpp
18+
struct tensor_quantization {
19+
std::string name;
20+
ggml_type quant = GGML_TYPE_COUNT;
21+
};
22+
1723
static void zeros(std::ofstream & file, size_t n) {
1824
char zero = 0;
1925
for (size_t i = 0; i < n; ++i) {
@@ -48,12 +54,6 @@ struct quantize_state_impl {
4854
{}
4955
};
5056

51-
// changes to this struct must be replicated in quantize.cpp
52-
struct tensor_quantization {
53-
std::string name;
54-
ggml_type quant = GGML_TYPE_COUNT;
55-
};
56-
5757
static void llama_tensor_dequantize_impl(
5858
ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
5959
const size_t nelements, const int nthread
@@ -796,17 +796,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
796796
// unless the user specifies a type
797797
if (params->tensor_types) {
798798
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
799+
const std::string tensor_name(tensor->name);
799800
for (const auto & [tname, qtype] : tensor_types) {
800-
if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
801-
if (qtype != new_type) {
802-
LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
801+
if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
802+
if (qtype != new_type) {
803+
LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
804+
new_type = qtype;
805+
break; // if two or more types are specified for the tensor, first match wins
803806
}
804-
new_type = qtype;
805-
break;
806807
}
807808
}
808809
}
809810
}
811+
810812
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
811813
new_type = params->token_embedding_type;
812814
}

tools/quantize/quantize.cpp

Lines changed: 13 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ static const std::vector<quant_option> QUANT_OPTIONS = {
5757
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
5858
};
5959

60+
// Quantization types. Changes to this struct must be replicated in llama-quantize.cpp
61+
struct tensor_quantization {
62+
std::string name;
63+
ggml_type quant = GGML_TYPE_COUNT;
64+
};
65+
6066
static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix.file";
6167
static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix.dataset";
6268
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
@@ -244,56 +250,10 @@ static ggml_type parse_ggml_type(const char * arg) {
244250
return type;
245251
}
246252
}
247-
fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg);
253+
fprintf(stderr, "\n%s: invalid ggml_type '%s'\n\n", __func__, arg);
248254
return GGML_TYPE_COUNT;
249255
}
250256

251-
// Allowed tensors for arbitrary quantization with --tensor-type option
252-
static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
253-
"attn_k",
254-
"attn_kv_a_mqa",
255-
"attn_kv_b",
256-
"attn_o",
257-
"attn_output",
258-
"attn_q",
259-
"attn_q_a",
260-
"attn_q_b",
261-
"attn_qkv",
262-
"attn_v",
263-
"channel_mix_key",
264-
"channel_mix_receptance",
265-
"channel_mix_value",
266-
"cls",
267-
"cls.output",
268-
"cross_attn_k",
269-
"cross_attn_o",
270-
"cross_attn_q",
271-
"cross_attn_v",
272-
"ffn_act",
273-
"ffn_down",
274-
"ffn_down_exps",
275-
"ffn_down_shexp",
276-
"ffn_gate",
277-
"ffn_gate_exps",
278-
"ffn_gate_shexp",
279-
"ffn_up",
280-
"ffn_up_exps",
281-
"ffn_up_shexp",
282-
"ssm_in",
283-
"ssm_out",
284-
"time_mix_gate",
285-
"time_mix_key",
286-
"time_mix_output",
287-
"time_mix_receptance",
288-
"time_mix_value",
289-
};
290-
291-
// changes to this struct must be replicated in llama-quant.cpp
292-
struct tensor_quantization {
293-
std::string name;
294-
ggml_type quant = GGML_TYPE_COUNT;
295-
};
296-
297257
static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
298258
const char * sep = strchr(data, '=');
299259
if (sep == nullptr) {
@@ -306,7 +266,6 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
306266
printf("\n%s: missing tensor name\n\n", __func__);
307267
return false;
308268
}
309-
310269
if (const size_t qt_len = strlen(sep); qt_len == 1) {
311270
printf("\n%s: missing quantization type\n\n", __func__);
312271
return false;
@@ -315,37 +274,15 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
315274
std::string tn(data, tn_len);
316275
std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
317276
sep++;
318-
const std::string qt(sep);
319-
320-
bool found = false;
321-
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
322-
std::string tensor;
323-
tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
324-
// handle special case of cls.output
325-
std::string cls_output = "cls.output";
326-
if (tn.find(cls_output) != std::string::npos) {
327-
tensor = "cls.output";
328-
}
329-
// check if an allowed tensor exists and it's at the end of the kv string
330-
if (tensor == allowed) {
331-
found = true;
332-
break;
333-
}
334-
}
335-
if (!found) {
336-
printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
337-
return false;
338-
}
339-
340-
if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
341-
printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
342-
return false;
343-
}
344-
345277
tensor_quantization tqz;
346278
tqz.name = tn;
347-
tqz.quant = parse_ggml_type(qt.c_str());
279+
tqz.quant = parse_ggml_type(sep);
348280
tensor_type.emplace_back(std::move(tqz));
281+
if (tqz.quant == GGML_TYPE_COUNT) {
282+
printf("\n%s: invalid quantization type '%s'\n\n", __func__, sep);
283+
return false;
284+
}
285+
349286
return true;
350287
}
351288

0 commit comments

Comments (0)