From 6303ea26401d7be51f5bf60113530dc683b450b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Wed, 21 May 2025 20:17:35 +0200
Subject: [PATCH 1/9] initial jina-embeddings-v3 support

---
 gguf-py/gguf/constants.py      | 14 ++++++++++++++
 gguf-py/gguf/tensor_mapping.py |  2 ++
 2 files changed, 16 insertions(+)

diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 7aed8b83ec378..618f871804b21 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -272,6 +272,7 @@ class MODEL_ARCH(IntEnum):
     NOMIC_BERT     = auto()
     NOMIC_BERT_MOE = auto()
     JINA_BERT_V2   = auto()
+    JINA_BERT_V3   = auto()
     BLOOM          = auto()
     STABLELM       = auto()
     QWEN           = auto()
@@ -534,6 +535,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.NOMIC_BERT:     "nomic-bert",
     MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
     MODEL_ARCH.JINA_BERT_V2:   "jina-bert-v2",
+    MODEL_ARCH.JINA_BERT_V3:   "jina-bert-v3",
     MODEL_ARCH.BLOOM:          "bloom",
     MODEL_ARCH.STABLELM:       "stablelm",
     MODEL_ARCH.QWEN:           "qwen",
@@ -1020,6 +1022,18 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.LAYER_OUT_NORM,
         MODEL_TENSOR.CLS,
     ],
+    MODEL_ARCH.JINA_BERT_V3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.MPT: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 0298f8b460a3f..b6eb770d8aa0f 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -157,6 +157,7 @@ class TensorNameMap:
             "h.{bid}.attn.c_attn",                                 # gpt2
             "transformer.h.{bid}.mixer.Wqkv",                      # phi2
             "encoder.layers.{bid}.attn.Wqkv",                      # nomic-bert
+            "encoder.layers.{bid}.mixer.Wqkv",                     # jina-bert-v3
             "model.layers.{bid}.self_attn.qkv_proj",               # phi3
             "encoder.layers.{bid}.self_attention.query_key_value", # chatglm
             "transformer.layers.{bid}.attn.qkv_proj",              # openelm
@@ -224,6 +225,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.self_attn.o_proj",                   # plamo
             "model.layers.{bid}.attention.wo",                              # internlm2
             "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
+            "encoder.layers.{bid}.mixer.out_proj",                          # jina-bert-v3
             "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",        # dbrx
             "encoder.layers.{bid}.self_attention.dense",                    # chatglm
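[Note on PATCH 1/9] This patch only registers the new architecture and the two extra
tensor-name mappings: jina-embeddings-v3 stores its fused QKV and attention output
projections under "encoder.layers.{bid}.mixer.*" rather than "attn.*". The mapping can
be sanity-checked from Python; a minimal sketch, assuming gguf-py from this branch is
importable and a 24-block model:

    import gguf
    from gguf.tensor_mapping import get_tensor_name_map

    # resolve an HF checkpoint tensor name to its GGUF name
    tmap = get_tensor_name_map(gguf.MODEL_ARCH.JINA_BERT_V3, 24)
    name = tmap.get_name("encoder.layers.0.mixer.Wqkv.weight", try_suffixes=(".weight", ".bias"))
    print(name)  # expected: blk.0.attn_qkv.weight
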
From ba51f891bce132c2fdfdac175f8c78ec863627f8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Wed, 21 May 2025 20:18:10 +0200
Subject: [PATCH 2/9] initial jina-embeddings-v3 support

---
 convert_hf_to_gguf.py | 44 +++++++++++++++++++++++++++++++++++++++----
 1 file changed, 40 insertions(+), 4 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index e88076ccccd2d..12fd6a3643cf0 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -802,6 +802,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec":
             # ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base
             res = "seed-coder"
+        if chkhsh == "a81863d07e75497e2194eb1a1574d5e5cd4d5f85a87a0728b922bf2bed6fb327":
+            # ref: https://huggingface.co/jinaai/jina-embeddings-v3
+            res = "jina-v3"
 
         if res is None:
             logger.warning("\n")
@@ -3829,12 +3832,24 @@ def _is_tokenizer_xlmroberta(self) -> bool:
 class XLMRobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._xlmroberta_tokenizer_init()
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            hparams = ModelBase.load_hparams(dir_model)
+
+        if hparams.get("lora_adaptations"):
+            self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+
+        super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
+
+        self._tokenizer_is_xlmroberta = False if self.model_arch == gguf.MODEL_ARCH.JINA_BERT_V3 else True
+        if self._tokenizer_is_xlmroberta:
+            self._xlmroberta_tokenizer_init()
 
     def set_vocab(self):
-        self._xlmroberta_set_vocab()
+        if self._tokenizer_is_xlmroberta:
+            return self._xlmroberta_set_vocab()
+        return super().set_vocab()
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # if name starts with "roberta.", remove the prefix
@@ -3842,13 +3857,34 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         if name.startswith("roberta."):
             name = name[8:]
 
+        # jina-embeddings-v3
+        if ".parametrizations." in name:
+            name = name.replace(".parametrizations.", ".")
+            if name.endswith(".original"):
+                name = name[:-9]
+
         # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
         if name == "embeddings.position_embeddings.weight":
             if self._position_offset is not None:
                 data_torch = data_torch[self._position_offset:,:]
 
+        if name.endswith(".lora_A"):
+            # TODO: convert loras
+            return []
+
+        if name.endswith(".lora_B"):
+            # TODO: convert loras
+            return []
+
         return super().modify_tensors(data_torch, name, bid)
 
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # jina-embeddings-v3
+        if rotary_emb_base := self.hparams.get("rotary_emb_base"):
+            self.gguf_writer.add_rope_freq_base(rotary_emb_base)
+
 
 @ModelBase.register("GemmaForCausalLM")
 class GemmaModel(TextModel):
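[Note on PATCH 2/9] jina-embeddings-v3 wraps its linear weights in torch
parametrizations (this is how the per-task LoRA adapters attach), so checkpoint names
carry a ".parametrizations.<...>.original" infix that must be stripped before the usual
BERT mapping applies; the LoRA tensors themselves are skipped for now. A sketch of the
rename, using an illustrative helper and an example tensor name:

    # mirrors the rename in modify_tensors(); "normalize" is not part of the patch
    def normalize(name: str) -> str:
        if ".parametrizations." in name:
            name = name.replace(".parametrizations.", ".")
            if name.endswith(".original"):
                name = name[:-len(".original")]  # the patch uses name[:-9]
        return name

    assert (normalize("encoder.layers.0.mixer.Wqkv.parametrizations.weight.original")
            == "encoder.layers.0.mixer.Wqkv.weight")
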
From 4ac1380531dc0b74d1a1e7d14e273216086aa5e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Wed, 21 May 2025 20:20:02 +0200
Subject: [PATCH 3/9] initial jina-embeddings-v3 support

---
 src/llama-arch.cpp  | 15 ++++++++++++++
 src/llama-arch.h    |  1 +
 src/llama-model.cpp | 50 +++++++++++++++++++++++++++++++++++++++++++--
 src/llama-model.h   |  1 +
 4 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index abf436adac416..19f58ba82db3d 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -21,6 +21,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_NOMIC_BERT,     "nomic-bert"     },
     { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
     { LLM_ARCH_JINA_BERT_V2,   "jina-bert-v2"   },
+    { LLM_ARCH_JINA_BERT_V3,   "jina-bert-v3"   },
     { LLM_ARCH_BLOOM,          "bloom"          },
     { LLM_ARCH_STABLELM,       "stablelm"       },
     { LLM_ARCH_QWEN,           "qwen"           },
@@ -513,6 +514,20 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_CLS,             "cls" },
         },
     },
+    {
+        LLM_ARCH_JINA_BERT_V3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_BLOOM,
         {
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 41a023da3da6e..c696b4113a4c7 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -25,6 +25,7 @@ enum llm_arch {
    LLM_ARCH_NOMIC_BERT,
    LLM_ARCH_NOMIC_BERT_MOE,
    LLM_ARCH_JINA_BERT_V2,
+    LLM_ARCH_JINA_BERT_V3,
    LLM_ARCH_BLOOM,
    LLM_ARCH_STABLELM,
    LLM_ARCH_QWEN,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 82557ea054bb2..b0512bf44a34d 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -41,6 +41,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_410M: return "410M";
         case LLM_TYPE_450M: return "450M";
         case LLM_TYPE_475M: return "475M";
+        case LLM_TYPE_558M: return "558M";
         case LLM_TYPE_770M: return "770M";
         case LLM_TYPE_780M: return "780M";
         case LLM_TYPE_0_5B: return "0.5B";
@@ -710,6 +711,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_JINA_BERT_V3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,        hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE,            hparams.pooling_type, false);
+
+                switch (hparams.n_layer) {
+                    case 24:
+                        type = LLM_TYPE_558M; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
@@ -2215,6 +2228,36 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
                     layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
 
+                    layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                    layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
+                }
+            } break;
+        case LLM_ARCH_JINA_BERT_V3:
+            {
+                tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
+                type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
+
+                tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                    layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                    layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                    layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
+                    layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
+
+                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),   {n_ff}, 0);
+
+                    layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                    layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
                     layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
                     layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
                 }
@@ -5931,7 +5974,7 @@ struct llm_build_bert : public llm_graph_context {
                 cur = build_lora_mm(model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
-                if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                if (model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
                     cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                     cb(cur, "bqkv", il);
                 }
@@ -6003,7 +6046,7 @@ struct llm_build_bert : public llm_graph_context {
                         0.0f,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
                 cb(cur, "ffn_moe_out", il);
-            } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+            } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
                 cur = build_ffn(cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
@@ -13187,6 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     switch (arch) {
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
@@ -13292,6 +13336,7 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
@@ -13658,6 +13703,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GROK:
         case LLM_ARCH_DBRX:
         case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_STABLELM:
diff --git a/src/llama-model.h b/src/llama-model.h
index cbea2cb331b62..77e36f12f252b 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -37,6 +37,7 @@ enum llm_type {
     LLM_TYPE_410M,
     LLM_TYPE_450M,
     LLM_TYPE_475M,
+    LLM_TYPE_558M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
     LLM_TYPE_0_5B,
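[Note on PATCH 3/9] The compute graph is reused from llm_build_bert; the new tensors are
just the fused QKV projection and its bias. The {n_embd, n_embd + 2*n_embd_gqa} shape is
Q, K and V concatenated along the output dimension; a sketch of the arithmetic, assuming
the published jina-embeddings-v3 config (hidden size 1024, 16 heads, no GQA):

    n_embd     = 1024
    n_head     = 16
    n_head_kv  = 16                              # no grouped-query attention
    n_embd_gqa = (n_embd // n_head) * n_head_kv  # K/V width, equals n_embd here

    qkv_rows = n_embd + 2 * n_embd_gqa           # fused Wqkv output dim / bqkv length
    assert qkv_rows == 3 * n_embd
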
"bqkv", il); } @@ -6003,7 +6046,7 @@ struct llm_build_bert : public llm_graph_context { 0.0f, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); - } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, @@ -13187,6 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, switch (arch) { case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: { @@ -13292,6 +13336,7 @@ llm_graph_result_ptr llama_model::build_graph( } break; case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: { @@ -13658,6 +13703,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GROK: case LLM_ARCH_DBRX: case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_STABLELM: diff --git a/src/llama-model.h b/src/llama-model.h index cbea2cb331b62..77e36f12f252b 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -37,6 +37,7 @@ enum llm_type { LLM_TYPE_410M, LLM_TYPE_450M, LLM_TYPE_475M, + LLM_TYPE_558M, LLM_TYPE_770M, LLM_TYPE_780M, LLM_TYPE_0_5B, From 1274c8c357102f123d4262677120707cd3b0624a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Thu, 22 May 2025 11:55:03 +0200 Subject: [PATCH 4/9] fix vocab parsing with only tokenizer.json --- convert_hf_to_gguf.py | 132 ++++++++++++++++++++++++++++-------------- 1 file changed, 87 insertions(+), 45 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 12fd6a3643cf0..bf7915a932ed9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -802,9 +802,6 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec": # ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base res = "seed-coder" - if chkhsh == "a81863d07e75497e2194eb1a1574d5e5cd4d5f85a87a0728b922bf2bed6fb327": - # ref: https://huggingface.co/jinaai/jina-embeddings-v3 - res = "jina-v3" if res is None: logger.warning("\n") @@ -3626,44 +3623,93 @@ def _xlmroberta_set_vocab(self) -> None: from sentencepiece import sentencepiece_model_pb2 as model tokenizer_path = self.dir_model / 'sentencepiece.bpe.model' + + tokenizer_json = {} + tokenizer_config_json = {} if not tokenizer_path.is_file(): - raise FileNotFoundError(f"File not found: {tokenizer_path}") + tokenizer_path = self.dir_model / 'tokenizer.json' + tokenizer_config_path = self.dir_model / 'tokenizer_config.json' - sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] - sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) - assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") - add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix - remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces - precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap + from base64 import b64decode + from transformers import AutoTokenizer + tokenizer = 
+            tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
 
-        tokenizer = SentencePieceProcessor()
-        tokenizer.LoadFromFile(str(tokenizer_path))
+            with open(tokenizer_path, "r", encoding="utf-8") as fp:
+                tokenizer_json = json.load(fp)
 
-        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+            if tokenizer_config_path.is_file():
+                with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
+                    tokenizer_config_json = json.load(fp)
+
+            add_prefix = tokenizer.add_prefix_space
+            remove_whitespaces = tokenizer.clean_up_tokenization_spaces
+            precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
+
+            vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size)
+        else:
+            sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
+            sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+            assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+
+            add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+            remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
+            precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+
+            tokenizer = SentencePieceProcessor()
+            tokenizer.LoadFromFile(str(tokenizer_path))
+
+            vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
 
         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
         toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
 
-        for token_id in range(tokenizer.vocab_size()):
-            piece = tokenizer.IdToPiece(token_id)
-            text = piece.encode("utf-8")
-            score = tokenizer.GetScore(token_id)
+        if isinstance(tokenizer, SentencePieceProcessor):
+            for token_id in range(vocab_size):
+                piece = tokenizer.IdToPiece(token_id)
+                text = piece.encode("utf-8")
+                score = tokenizer.GetScore(token_id)
 
-            toktype = SentencePieceTokenTypes.NORMAL
-            if tokenizer.IsUnknown(token_id):
-                toktype = SentencePieceTokenTypes.UNKNOWN
-            elif tokenizer.IsControl(token_id):
-                toktype = SentencePieceTokenTypes.CONTROL
-            elif tokenizer.IsUnused(token_id):
-                toktype = SentencePieceTokenTypes.UNUSED
-            elif tokenizer.IsByte(token_id):
-                toktype = SentencePieceTokenTypes.BYTE
+                toktype = SentencePieceTokenTypes.NORMAL
+                if tokenizer.IsUnknown(token_id):
+                    toktype = SentencePieceTokenTypes.UNKNOWN
+                elif tokenizer.IsControl(token_id):
+                    toktype = SentencePieceTokenTypes.CONTROL
+                elif tokenizer.IsUnused(token_id):
+                    toktype = SentencePieceTokenTypes.UNUSED
+                elif tokenizer.IsByte(token_id):
+                    toktype = SentencePieceTokenTypes.BYTE
 
-            tokens[token_id] = text
-            scores[token_id] = score
-            toktypes[token_id] = toktype
+                tokens[token_id] = text
+                scores[token_id] = score
+                toktypes[token_id] = toktype
+        else:
+            added_vocab = tokenizer.get_added_vocab()
+            unk_token = tokenizer_config_json.get("unk_token")
+            unk_token_id = added_vocab.get(unk_token, 3)
+
+            for token_id in range(vocab_size):
+                piece = tokenizer._convert_id_to_token(token_id)
+                text = piece.encode("utf-8")
+                score = tokenizer_json["model"]["vocab"][token_id][1]
+
+                toktype = SentencePieceTokenTypes.NORMAL
+                if token_id == unk_token_id:
+                    toktype = SentencePieceTokenTypes.UNKNOWN
+                elif token_id in tokenizer.all_special_ids:
+                    toktype = SentencePieceTokenTypes.CONTROL
+                elif token_id in added_vocab.values():
+                    toktype = SentencePieceTokenTypes.USER_DEFINED
+                # No reliable way to detect this, but jina-embeddings-v3 doesn't have any
+                # elif tokenizer.IsByte(token_id):
+                #     toktype = SentencePieceTokenTypes.BYTE
+
+                tokens[token_id] = text
+                scores[token_id] = score
+                toktypes[token_id] = toktype
 
         if vocab_size > len(tokens):
             pad_count = vocab_size - len(tokens)
@@ -3673,15 +3719,16 @@ def _xlmroberta_set_vocab(self) -> None:
             scores.append(-1000.0)
             toktypes.append(SentencePieceTokenTypes.UNUSED)
 
-        # realign tokens (see HF tokenizer code)
-        tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
-        scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
-        toktypes = [
-            SentencePieceTokenTypes.CONTROL,
-            SentencePieceTokenTypes.CONTROL,
-            SentencePieceTokenTypes.CONTROL,
-            SentencePieceTokenTypes.UNKNOWN,
-        ] + toktypes[3:-1]
+        if isinstance(tokenizer, SentencePieceProcessor):
+            # realign tokens (see HF tokenizer code)
+            tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
+            scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
+            toktypes = [
+                SentencePieceTokenTypes.CONTROL,
+                SentencePieceTokenTypes.CONTROL,
+                SentencePieceTokenTypes.CONTROL,
+                SentencePieceTokenTypes.UNKNOWN,
+            ] + toktypes[3:-1]
 
         self.gguf_writer.add_tokenizer_model("t5")
         self.gguf_writer.add_tokenizer_pre("default")
@@ -3841,15 +3888,10 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
             self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
 
         super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
-
-        self._tokenizer_is_xlmroberta = False if self.model_arch == gguf.MODEL_ARCH.JINA_BERT_V3 else True
-        if self._tokenizer_is_xlmroberta:
-            self._xlmroberta_tokenizer_init()
+        self._xlmroberta_tokenizer_init()
 
     def set_vocab(self):
-        if self._tokenizer_is_xlmroberta:
-            return self._xlmroberta_set_vocab()
-        return super().set_vocab()
+        self._xlmroberta_set_vocab()
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # if name starts with "roberta.", remove the prefix

From 65a37fa8e5f83afd36dd84ea9dbafb7c5e56388a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Thu, 22 May 2025 15:19:04 +0200
Subject: [PATCH 5/9] set mask token lstrip attribute

---
 src/llama-vocab.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 9389ca805a584..6d6a1c6e18782 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -2080,9 +2080,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
     std::string model_name;
     std::string tokenizer_pre;
+    std::string general_arch;
 
     ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
     ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+    ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);
 
     // model name to lowercase
     std::transform(model_name.begin(), model_name.end(), model_name.begin(),
@@ -2091,8 +2093,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         }
     );
 
-    // set attributes by model/tokenizer name
-    if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
+    // set attributes by model/tokenizer/architecture name
+    if (false
+        || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
+        || _contains_any(general_arch, {"jina-bert-v3"})
+        ) {
         _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
     } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
         for (auto id : cache_special_tokens) {
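[Note on PATCH 5/9] XLM-RoBERTa defines its mask token with lstrip=True on the HF side,
i.e. "<mask>" consumes the whitespace directly in front of it, so "the <mask>" and
"the<mask>" must encode identically. Since jina-bert-v3 has no distinctive tokenizer_pre
value, the attribute is now also keyed off the GGUF architecture name. A rough
plain-Python illustration of the intended behaviour (not the actual tokenizer code):

    def fold_space_before_mask(text: str) -> str:
        # with LSTRIP set, a space immediately before <mask> belongs to the token
        return text.replace(" <mask>", "<mask>")

    assert (fold_space_before_mask("Paris is the <mask> of France.")
            == "Paris is the<mask> of France.")
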
From f2d876ad5ab5013d4df5848038e18e2f8b51760a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Fri, 23 May 2025 09:22:45 +0200
Subject: [PATCH 6/9] additional unk_token_id fallback just in case [no ci]

---
 convert_hf_to_gguf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index bf7915a932ed9..0f2c41ecc8b3d 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3689,7 +3689,7 @@ def _xlmroberta_set_vocab(self) -> None:
         else:
             added_vocab = tokenizer.get_added_vocab()
             unk_token = tokenizer_config_json.get("unk_token")
-            unk_token_id = added_vocab.get(unk_token, 3)
+            unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))
 
             for token_id in range(vocab_size):
                 piece = tokenizer._convert_id_to_token(token_id)

From b17e9811f47adc32682042a238219ea3b37d7e32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Mon, 26 May 2025 08:40:46 +0200
Subject: [PATCH 7/9] revert vocab_size() change [no ci]

---
 convert_hf_to_gguf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 0f2c41ecc8b3d..753c88e7c5f17 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3668,7 +3668,7 @@ def _xlmroberta_set_vocab(self) -> None:
         toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
 
         if isinstance(tokenizer, SentencePieceProcessor):
-            for token_id in range(vocab_size):
+            for token_id in range(tokenizer.vocab_size()):
                 piece = tokenizer.IdToPiece(token_id)
                 text = piece.encode("utf-8")
                 score = tokenizer.GetScore(token_id)
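[Note on PATCH 8/9] Rather than keeping a dedicated LLM_ARCH_JINA_BERT_V3 loader, the
biases in the generic BERT branch become TENSOR_NOT_REQUIRED, so each architecture loads
whichever bias tensors its GGUF actually contains and the per-arch if/else chains
disappear. A Python stand-in for create_tensor (hypothetical, for illustration only):

    def create_tensor(tensors: dict, name: str, required: bool = True):
        t = tensors.get(name)
        if t is None and required:
            raise KeyError(f"missing tensor: {name}")
        return t  # None is acceptable for optional tensors

    weights = {"blk.0.attn_output.weight": ...}  # e.g. nomic-bert ships no such bias
    bo = create_tensor(weights, "blk.0.attn_output.bias", required=False)
    assert bo is None                            # downstream graph code must tolerate this
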
From f5d0305d51de45f7f4642e16349f27b37862502a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Sun, 1 Jun 2025 20:55:03 +0200
Subject: [PATCH 8/9] merge tensor loading into general bert

---
 src/llama-model.cpp | 47 ++++++++-------------------------------------
 1 file changed, 8 insertions(+), 39 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index e46403f3c2011..31b27106315c1 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2128,6 +2128,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_BERT:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
+        case LLM_ARCH_JINA_BERT_V3:
             {
                 tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
                 type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
@@ -2163,24 +2164,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
 
                     layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                    layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
 
                     layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
                     layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
 
                     if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) {
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
                         layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff,   n_expert}, 0);
                         layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff,   n_embd, n_expert}, 0);
                         layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
                     } else {
-                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-
-                        if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) {
-                            layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-                            layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
-                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
-                        } else {
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        if (arch == LLM_ARCH_NOMIC_BERT) {
                             layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                         }
                     }
@@ -2232,36 +2231,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
                     layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
 
-                    layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
-                    layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
-                }
-            } break;
-        case LLM_ARCH_JINA_BERT_V3:
-            {
-                tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
-                type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
-
-                tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
-
-                for (int i = 0; i < n_layer; ++i) {
-                    auto & layer = layers[i];
-
-                    layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                    layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                    layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                    layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
-                    layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
-
-                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
-                    layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),   {n_ff}, 0);
-
-                    layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                    layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
                     layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
                     layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
                 }
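[Note on PATCH 9/9] Unlike classic BERT's learned absolute positions, jina-embeddings-v3
uses rotary position embeddings, so the existing nomic-bert RoPE path is reused; freq_base
is taken from the rotary_emb_base hparam written in PATCH 2/9. A minimal scalar sketch of
the rotation ggml_rope_ext applies (pairing convention and the remaining arguments
omitted; not the actual ggml code):

    import math

    def rope(x: list[float], pos: int, freq_base: float) -> list[float]:
        out = list(x)
        for i in range(0, len(x), 2):            # rotate dimensions in pairs
            theta = pos * freq_base ** (-i / len(x))
            c, s = math.cos(theta), math.sin(theta)
            out[i]     = x[i] * c - x[i + 1] * s
            out[i + 1] = x[i] * s + x[i + 1] * c
        return out
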
From 3862d954bbfa85cc4301cc7cd61548956c09db01 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Sun, 1 Jun 2025 21:46:15 +0200
Subject: [PATCH 9/9] rope

---
 src/llama-model.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 31b27106315c1..9ac88c27c5f79 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -5961,7 +5961,7 @@ struct llm_build_bert : public llm_graph_context {
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
             // RoPE
-            if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+            if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,