diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 1dc18b2a57721..468b9d7c159f5 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1434,6 +1434,8 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
         self.gguf_writer.add_file_type(self.ftype)
+        if "tie_lm_head" in self.hparams:
+            self.gguf_writer.add_tie_lm_head(self.hparams["tie_lm_head"])
 
     def set_vocab(self):
         self._set_vocab_llama_hf()
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 5951c0bb0fb5e..e9fd85938e859 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -43,6 +43,7 @@ class LLM:
         EXPERT_USED_COUNT = "{arch}.expert_used_count"
         POOLING_TYPE      = "{arch}.pooling_type"
         LOGIT_SCALE       = "{arch}.logit_scale"
+        TIE_LM_HEAD       = "{arch}.tie_lm_head"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -612,6 +613,7 @@ class MODEL_TENSOR(IntEnum):
     ],
     MODEL_ARCH.MINICPM: [
         MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_NORM,
@@ -916,6 +918,7 @@ def get_type(val: Any) -> GGUFValueType:
 KEY_FEED_FORWARD_LENGTH   = Keys.LLM.FEED_FORWARD_LENGTH
 KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL
 KEY_TENSOR_DATA_LAYOUT    = Keys.LLM.TENSOR_DATA_LAYOUT
+KEY_TIE_LM_HEAD           = Keys.LLM.TIE_LM_HEAD
 
 # attention
 KEY_ATTENTION_HEAD_COUNT  = Keys.Attention.HEAD_COUNT
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 8dcf9330b076f..228773ff9d4e0 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -416,6 +416,9 @@ def add_feed_forward_length(self, length: int) -> None:
 
     def add_parallel_residual(self, use: bool) -> None:
         self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
+
+    def add_tie_lm_head(self, tie_lm_head: bool) -> None:
+        self.add_bool(Keys.LLM.TIE_LM_HEAD.format(arch=self.arch), tie_lm_head)
 
     def add_head_count(self, count: int) -> None:
         self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
diff --git a/llama.cpp b/llama.cpp
index e7b3fd8b433b4..8f2f7871f9dc6 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -287,6 +287,7 @@ enum llm_kv {
     LLM_KV_EXPERT_USED_COUNT,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
+    LLM_KV_TIE_LM_HEAD,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -364,6 +365,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_EXPERT_USED_COUNT,       "%s.expert_used_count"       },
    { LLM_KV_POOLING_TYPE ,           "%s.pooling_type"            },
    { LLM_KV_LOGIT_SCALE,             "%s.logit_scale"             },
+   { LLM_KV_TIE_LM_HEAD,             "%s.tie_lm_head"             },
 
    { LLM_KV_ATTENTION_HEAD_COUNT,    "%s.attention.head_count"    },
    { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -1845,6 +1847,8 @@ struct llama_hparams {
     float f_logit_scale = 0.0f;
 
     bool causal_attn = true;
+    bool need_kq_pos = false;
+    bool tie_lm_head = true;
     bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
@@ -3786,6 +3790,7 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
     ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_TIE_LM_HEAD, hparams.tie_lm_head, false);
 
     GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
     GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
@@ -4870,20 +4875,12 @@ static bool llm_load_tensors(
             case LLM_ARCH_MINICPM:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    if (!hparams.tie_lm_head){
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+                    }
 
                     // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        if (model.arch != LLM_ARCH_MINICPM){
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
-                            // if output is NULL, init from the input tok embed
-                            if (model.output == NULL) {
-                                model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                                ml.n_created--; // artificial tensor
-                                ml.size_data += ggml_nbytes(model.output);
-                            }
-                        }
-                    }
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
 
                     for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
@@ -10041,7 +10038,11 @@ struct llm_build_context {
         cb(cur, "lmhead_scaling", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+        if (hparams.tie_lm_head){
+            cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+        }else{
+            cur = ggml_mul_mat(ctx0, model.output, cur);
+        }
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);