From 3453e62bb91e2c9ee173c65e9e00479022d183aa Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Sep 2024 16:59:17 +0300 Subject: [PATCH 01/24] py : add XLMRobertaForSequenceClassification [no ci] --- convert_hf_to_gguf.py | 7 ++++++- gguf-py/gguf/constants.py | 6 ++++++ gguf-py/gguf/tensor_mapping.py | 8 ++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7be609054d6b8..a7146442f7dd9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2598,7 +2598,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) -@Model.register("XLMRobertaModel") +@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") class XLMRobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT @@ -2701,6 +2701,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if self._position_offset is not None: data_torch = data_torch[self._position_offset:,:] + # if name starts with "roberta.", remove the prefix + # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main + if name.startswith("roberta."): + name = name[8:] + return super().modify_tensors(data_torch, name, bid) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 560eee916f27e..414a3ae211dd4 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -343,6 +343,8 @@ class MODEL_TENSOR(IntEnum): ENC_FFN_DOWN = auto() ENC_FFN_UP = auto() ENC_OUTPUT_NORM = auto() + CLS = auto() # classifier + CLS_OUT = auto() # classifier output projection MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -501,6 +503,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + MODEL_TENSOR.CLS: "cls", + MODEL_TENSOR.CLS_OUT: "cls.output", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -610,6 +614,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 4e850726e9ba4..4a34a549de18b 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -679,6 +679,14 @@ class TensorNameMap: MODEL_TENSOR.ENC_OUTPUT_NORM: ( "encoder.final_layer_norm", # t5 ), + + MODEL_TENSOR.CLS: ( + "classifier.dense", # roberta + ), + + MODEL_TENSOR.CLS_OUT: ( + "classifier.out_proj", # roberta + ), } # architecture-specific block mappings From 77723ed69e8a5af03cb4b945b4565e6cc6da5325 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 17 Sep 2024 13:40:52 +0300 Subject: [PATCH 02/24] py : fix scalar-tensor conversion [no ci] --- convert_hf_to_gguf.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a7146442f7dd9..ca020f8764dca 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -291,8 +291,13 @@ def prepare_tensors(self): bid = int(part) break - for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): - data: np.ndarray # type hint + for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): + data = data_torch.squeeze().numpy() + + # if data ends up empty, it means data_torch was a scalar tensor -> restore + if len(data.shape) == 0: + data = data_torch.numpy() + n_dims = len(data.shape) data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) From 49f90de363e97fd9071a1f2fbe8ccb107be7f478 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 17 Sep 2024 13:53:19 +0300 Subject: [PATCH 03/24] py : fix position embeddings chop [no ci] --- convert_hf_to_gguf.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ca020f8764dca..d69a0d9f8270c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2701,16 +2701,16 @@ def set_vocab(self): self.gguf_writer.add_add_eos_token(True) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # position embeddings start at pad_token_id + 1, so just chop down the weight tensor - if name == "embeddings.position_embeddings.weight": - if self._position_offset is not None: - data_torch = data_torch[self._position_offset:,:] - # if name starts with "roberta.", remove the prefix # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main if name.startswith("roberta."): name = name[8:] + # position embeddings start at pad_token_id + 1, so just chop down the weight tensor + if name == "embeddings.position_embeddings.weight": + if self._position_offset is not None: + data_torch = data_torch[self._position_offset:,:] + return super().modify_tensors(data_torch, name, bid) From dc0cdd8760547b041bb206541e3c9cf9bb879777 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 17 Sep 2024 16:38:38 +0300 Subject: [PATCH 04/24] llama : read new cls tensors [no ci] --- src/llama.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index 0accb1492efaa..ab0328ce89b31 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -602,6 +602,8 @@ enum llm_tensor { LLM_TENSOR_ENC_FFN_DOWN, LLM_TENSOR_ENC_FFN_UP, LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, }; static const std::map> LLM_TENSOR_NAMES = { @@ -789,6 +791,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, }, }, { @@ -2882,6 +2886,12 @@ struct llama_model { struct ggml_tensor * output_b; struct ggml_tensor * output_norm_enc; + // classifier + struct ggml_tensor * cls; + struct ggml_tensor * cls_b; + struct ggml_tensor * cls_out; + struct ggml_tensor * cls_out_b; + std::vector layers; llama_split_mode split_mode; @@ -7351,6 +7361,12 @@ static bool llm_load_tensors( if (model.arch == LLM_ARCH_BERT) { model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + + model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + model.cls_out = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_out_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); From d0a7bf9382782368b57e68585b8926aa875a2f95 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 18 Sep 2024 21:20:21 +0300 Subject: [PATCH 05/24] llama : add classigication head (wip) [no ci] --- common/arg.cpp | 2 +- src/llama.cpp | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 6880117ed8001..885e982bb0fd2 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -391,7 +391,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, [](gpt_params & params) { params.verbose_prompt = true; } - ).set_examples({LLAMA_EXAMPLE_MAIN})); + )); add_opt(llama_arg( {"--no-display-prompt"}, format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"), diff --git a/src/llama.cpp b/src/llama.cpp index ab0328ce89b31..86731bf68fe7d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -11455,8 +11455,20 @@ struct llm_build_context { inpL = cur; } - // final output cur = inpL; + + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 + // TODO: become pooling layer? + if (model.cls) { + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b); + + cur = ggml_tanh(ctx0, cur); + + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); + // TODO: cur is now a scalar - what to do? + } + cb(cur, "result_embd", -1); ggml_build_forward_expand(gf, cur); From 125a0671ab0507679af6d993fe5181b5c430c9d5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 19 Sep 2024 13:21:15 +0300 Subject: [PATCH 06/24] llama : add "rank" pooling type ggml-ci --- common/arg.cpp | 3 ++- examples/embedding/embedding.cpp | 4 +++ include/llama.h | 1 + src/llama.cpp | 43 ++++++++++++++++++++------------ 4 files changed, 34 insertions(+), 17 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 885e982bb0fd2..498a0cf981efd 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1098,8 +1098,9 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, [](gpt_params & params, const std::string & value) { /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } - else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } else { throw std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index a438dcb5adf34..a0ca9d98c978b 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -234,6 +234,10 @@ int main(int argc, char ** argv) { } LOG("\n"); } + } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { + for (int j = 0; j < n_embd_count; j++) { + LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]); + } } else { // print the first part of the embeddings or for a single prompt, the full embedding for (int j = 0; j < n_prompts; j++) { diff --git a/include/llama.h b/include/llama.h index 132937a0700e7..6601b3444864a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -192,6 +192,7 @@ extern "C" { LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_LAST = 3, + LLAMA_POOLING_TYPE_RANK = 4, }; enum llama_attention_type { diff --git a/src/llama.cpp b/src/llama.cpp index 86731bf68fe7d..1346198f8e59d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10213,6 +10213,10 @@ struct llm_build_context { struct ggml_tensor * cur; switch (pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + cur = inp; + } break; case LLAMA_POOLING_TYPE_MEAN: { struct ggml_tensor * inp_mean = build_inp_mean(); @@ -10224,9 +10228,24 @@ struct llm_build_context { struct ggml_tensor * inp_cls = build_inp_cls(); cur = ggml_get_rows(ctx0, inp, inp_cls); } break; - case LLAMA_POOLING_TYPE_NONE: + case LLAMA_POOLING_TYPE_RANK: { - cur = inp; + struct ggml_tensor * inp_cls = build_inp_cls(); + inp = ggml_get_rows(ctx0, inp, inp_cls); + + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 + GGML_ASSERT(model.cls != nullptr); + GGML_ASSERT(model.cls_b != nullptr); + GGML_ASSERT(model.cls_out != nullptr); + GGML_ASSERT(model.cls_out_b != nullptr); + + cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b); + cur = ggml_tanh(ctx0, cur); + cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); + + // broadcast across the embedding size to make it compatible with the llama_get_embeddings API + cur = ggml_repeat(ctx0, cur, inp); } break; default: { @@ -11457,18 +11476,6 @@ struct llm_build_context { cur = inpL; - // classification head - // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 - // TODO: become pooling layer? - if (model.cls) { - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b); - - cur = ggml_tanh(ctx0, cur); - - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); - // TODO: cur is now a scalar - what to do? - } - cb(cur, "result_embd", -1); ggml_build_forward_expand(gf, cur); @@ -16446,7 +16453,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { + if (cparams.embeddings && ( + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { const int64_t n_tokens = batch.n_tokens; const int64_t n_seq_tokens = batch.n_seq_tokens; const int64_t n_seqs = batch.n_seqs; @@ -16461,7 +16470,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const llama_seq_id seq_id = batch.seq_id[s][0]; // TODO: adapt limits to n_seqs when batch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); for (int i = 0; i < n_seq_tokens; ++i) { const llama_pos pos = batch.pos[s*n_seq_tokens + i]; @@ -16988,6 +16997,7 @@ static int llama_decode_internal( case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: + case LLAMA_POOLING_TYPE_RANK: { // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = lctx.embd_seq; @@ -17191,6 +17201,7 @@ static int llama_encode_internal( case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: + case LLAMA_POOLING_TYPE_RANK: { // extract sequence embeddings auto & embd_seq_out = lctx.embd_seq; From 6235c62ac952f64cf0fd940056eafb97b1bae402 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 19 Sep 2024 16:18:30 +0300 Subject: [PATCH 07/24] server : add rerank endpoint ggml-ci --- common/arg.cpp | 2 +- examples/server/server.cpp | 196 ++++++++++++++++++++++++++++++++++--- examples/server/utils.hpp | 25 ++++- 3 files changed, 209 insertions(+), 14 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 498a0cf981efd..12f05cc20cb4c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1093,7 +1093,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, } ).set_sparam()); add_opt(llama_arg( - {"--pooling"}, "{none,mean,cls,last}", + {"--pooling"}, "{none,mean,cls,last, rank}", "pooling type for embeddings, use model default if unspecified", [](gpt_params & params, const std::string & value) { /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 61ff09bb2b40f..71d29002d74a0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -92,6 +92,7 @@ enum server_task_type { enum server_task_cmpl_type { SERVER_TASK_CMPL_TYPE_NORMAL, SERVER_TASK_CMPL_TYPE_EMBEDDING, + SERVER_TASK_CMPL_TYPE_RERANK, SERVER_TASK_CMPL_TYPE_INFILL, }; @@ -172,6 +173,7 @@ struct server_slot { std::vector generated_token_probs; server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + bool has_next_token = true; bool truncated = false; bool stopped_eos = false; @@ -954,8 +956,17 @@ struct server_context { slot.prompt = *prompt; } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { slot.prompt = prompt->at(0); + } else if (prompt->is_array() && prompt->size() > 1) { + // array of strings + for (const auto & el : *prompt) { + if (!el.is_string()) { + send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + slot.prompt = *prompt; } else { - send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); + send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST); return false; } } @@ -1389,6 +1400,7 @@ struct server_context { res.data = json { {"embedding", std::vector(n_embd, 0.0f)}, + {"index", slot.index}, }; continue; @@ -1407,6 +1419,44 @@ struct server_context { queue_results.send(res); } + void send_rank(const server_slot & slot, const llama_batch & batch) { + server_task_result res; + res.id = slot.id_task; + res.error = false; + res.stop = true; + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res.data = json { + {"index", slot.index}, + {"rank", -1e6}, + }; + + continue; + } + + res.data = json { + {"index", slot.index}, + {"rank", embd[0]}, + }; + } + + SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str()); + + queue_results.send(res); + } + // // Functions to create new task(s) and receive result(s) // @@ -1442,13 +1492,23 @@ struct server_context { // otherwise, it's a multiple-prompt task, we break it into smaller tasks else if (prompt.is_array()) { std::vector prompts = prompt; - for (size_t i = 0; i < prompts.size(); i++) { - const auto & e = prompts[i]; - if (e.is_string() || json_is_array_of_numbers(e)) { - data["index"] = i; - create_task(data, true, e); - } else { - throw std::runtime_error(error_msg); + if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + for (size_t i = 1; i < prompts.size(); i++) { + json qd; + qd.push_back(prompts[0]); + qd.push_back(prompts[i]); + data["index"] = i - 1; + create_task(data, true, qd); + } + } else { + for (size_t i = 0; i < prompts.size(); i++) { + const auto & e = prompts[i]; + if (e.is_string() || json_is_array_of_numbers(e)) { + data["index"] = i; + create_task(data, true, e); + } else { + throw std::runtime_error(error_msg); + } } } } @@ -1492,7 +1552,9 @@ struct server_context { return; } - size_t idx = result.data["index"]; + const size_t idx = result.data["index"]; + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = result; } result_handler(results); @@ -1951,6 +2013,29 @@ struct server_context { } prompt_tokens = embd_inp; + } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + // require slot.prompt to be array of 2 strings + if (!slot.prompt.is_array() || slot.prompt.size() != 2) { + SLT_ERR(slot, "%s", "invalid prompt for rerank task\n"); + slot.release(); + send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST); + continue; + } + + // prompt: querydoc + prompt_tokens.clear(); + prompt_tokens.push_back(llama_token_bos(model)); + { + const auto part = tokenize(slot.prompt[0], false); + prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end()); + } + prompt_tokens.push_back(llama_token_eos(model)); + prompt_tokens.push_back(llama_token_bos(model)); + { + const auto part = tokenize(slot.prompt[1], false); + prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end()); + } + prompt_tokens.push_back(llama_token_eos(model)); } else { prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt } @@ -1970,7 +2055,7 @@ struct server_context { continue; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { // this prompt is too large to process - discard it if (slot.n_prompt_tokens > n_ubatch) { slot.release(); @@ -2048,7 +2133,7 @@ struct server_context { slot.n_prompt_tokens_processed = 0; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; @@ -2056,7 +2141,10 @@ struct server_context { } // check that we are in the right batch_type, if not defer the slot - bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0; + const bool slot_type = + slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || + slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0; + if (batch_type == -1) { batch_type = slot_type; } else if (batch_type != slot_type) { @@ -2229,6 +2317,13 @@ struct server_context { continue; // continue loop of slots } + if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + send_rank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; } else if (slot.state != SLOT_STATE_GENERATING) { @@ -3023,6 +3118,82 @@ int main(int argc, char ** argv) { res_ok(res, root); }; + const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + const json body = json::parse(req.body); + + // TODO: implement + //int top_n = 1; + //if (body.count("top_n") != 1) { + // top_n = body.at("top_n"); + //} else { + // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + // return; + //} + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } else { + exit(0); + res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + json documents; + if (body.count("documents") != 0) { + documents = body.at("documents"); + if (!documents.is_array() || documents.size() == 0) { + res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } else { + res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + // construct prompt object: array of ["query", "doc0", "doc1", ...] + json prompt; + prompt.push_back(query); + for (const auto & doc : documents) { + prompt.push_back(doc); + } + + LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str()); + + // create and queue the task + json responses = json::array(); + bool error = false; + { + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); + + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); + + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + for (const auto & res : results) { + responses.push_back(res.data); + } + }, [&](const json & error_data) { + res_error(res, error_data); + error = true; + }); + } + + if (error) { + return; + } + + // write JSON response + json root = format_response_rerank(body, responses); + res_ok(res, root); + }; + const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); for (size_t i = 0; i < ctx_server.loras.size(); ++i) { @@ -3119,6 +3290,7 @@ int main(int argc, char ** argv) { svr->Post("/embedding", handle_embeddings); // legacy svr->Post("/embeddings", handle_embeddings); svr->Post("/v1/embeddings", handle_embeddings); + svr->Post("/v1/rerank", handle_rerank); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); // LoRA adapters hotswap diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index f093f547ff2c1..91e7f792d28d6 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -537,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso json res = json { {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"object", "list"}, - {"usage", json { + {"usage", json { // TODO: fill {"prompt_tokens", 0}, {"total_tokens", 0} }}, @@ -547,6 +547,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso return res; } +static json format_response_rerank(const json & request, const json & ranks) { + json data = json::array(); + int i = 0; + for (const auto & rank : ranks) { + data.push_back(json{ + {"index", i++}, + {"relevance_score", json_value(rank, "rank", 0.0)}, + }); + } + + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { // TODO: fill + {"prompt_tokens", 0}, + {"total_tokens", 0} + }}, + {"results", data} + }; + + return res; +} + static bool is_valid_utf8(const std::string & str) { const unsigned char* bytes = reinterpret_cast(str.data()); const unsigned char* end = bytes + str.length(); From 6916ed160673d47e1e4f809f5b27ee68e2d9039e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 23 Sep 2024 20:20:38 +0300 Subject: [PATCH 08/24] llama : aboud ggml_repeat during classification --- src/llama.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 1346198f8e59d..f0f7b67cf801c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10243,9 +10243,6 @@ struct llm_build_context { cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b); cur = ggml_tanh(ctx0, cur); cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); - - // broadcast across the embedding size to make it compatible with the llama_get_embeddings API - cur = ggml_repeat(ctx0, cur, inp); } break; default: { @@ -16997,7 +16994,6 @@ static int llama_decode_internal( case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: - case LLAMA_POOLING_TYPE_RANK: { // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = lctx.embd_seq; @@ -17011,6 +17007,20 @@ static int llama_decode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); } } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rank score - a single float per sequence + auto & embd_seq_out = lctx.embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(1); + ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); + } + } break; case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type"); From 62a45d12ef4b42d5d5c0172e19ef41b17ba71a09 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Sep 2024 16:58:54 +0300 Subject: [PATCH 09/24] rerank : cleanup + comments --- examples/embedding/embedding.cpp | 2 +- examples/server/server.cpp | 16 +++++++++++----- examples/server/utils.hpp | 2 +- include/llama.h | 11 ++++++----- src/llama.cpp | 10 ++++++++-- 5 files changed, 27 insertions(+), 14 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index a0ca9d98c978b..18d6512608901 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -236,7 +236,7 @@ int main(int argc, char ** argv) { } } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { for (int j = 0; j < n_embd_count; j++) { - LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]); + LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); } } else { // print the first part of the embeddings or for a single prompt, the full embedding diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 71d29002d74a0..ce65164d1f2ac 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1419,7 +1419,7 @@ struct server_context { queue_results.send(res); } - void send_rank(const server_slot & slot, const llama_batch & batch) { + void send_rerank(const server_slot & slot, const llama_batch & batch) { server_task_result res; res.id = slot.id_task; res.error = false; @@ -1440,7 +1440,7 @@ struct server_context { res.data = json { {"index", slot.index}, - {"rank", -1e6}, + {"score", -1e6}, }; continue; @@ -1448,11 +1448,11 @@ struct server_context { res.data = json { {"index", slot.index}, - {"rank", embd[0]}, + {"score", embd[0]}, }; } - SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str()); + SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str()); queue_results.send(res); } @@ -1493,6 +1493,9 @@ struct server_context { else if (prompt.is_array()) { std::vector prompts = prompt; if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { + // prompts[0] is the question + // the rest are the answers/documents + SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1); for (size_t i = 1; i < prompts.size(); i++) { json qd; qd.push_back(prompts[0]); @@ -1501,6 +1504,7 @@ struct server_context { create_task(data, true, qd); } } else { + SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size()); for (size_t i = 0; i < prompts.size(); i++) { const auto & e = prompts[i]; if (e.is_string() || json_is_array_of_numbers(e)) { @@ -1965,6 +1969,7 @@ struct server_context { // track if this is an embedding or non-embedding batch // if we've added sampled tokens above, we are in non-embedding mode // -1: none, 0: non-embedding, 1: embedding + // TODO: make enum int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; // next, batch any pending prompts without exceeding n_batch @@ -2133,6 +2138,7 @@ struct server_context { slot.n_prompt_tokens_processed = 0; } + // non-causal tasks require to fit the entire prompt in the physical batch if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { @@ -2318,7 +2324,7 @@ struct server_context { } if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { - send_rank(slot, batch_view); + send_rerank(slot, batch_view); slot.release(); slot.i_batch = -1; continue; // continue loop of slots diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 91e7f792d28d6..47dfdfde512dc 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -553,7 +553,7 @@ static json format_response_rerank(const json & request, const json & ranks) { for (const auto & rank : ranks) { data.push_back(json{ {"index", i++}, - {"relevance_score", json_value(rank, "rank", 0.0)}, + {"relevance_score", json_value(rank, "score", 0.0)}, }); } diff --git a/include/llama.h b/include/llama.h index 6601b3444864a..94341d78acb54 100644 --- a/include/llama.h +++ b/include/llama.h @@ -192,7 +192,7 @@ extern "C" { LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_LAST = 3, - LLAMA_POOLING_TYPE_RANK = 4, + LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph }; enum llama_attention_type { @@ -202,9 +202,9 @@ extern "C" { }; enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) @@ -872,7 +872,8 @@ extern "C" { // Get the embeddings for a sequence id // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE - // shape: [n_embd] (1-dimensional) + // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence + // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); // diff --git a/src/llama.cpp b/src/llama.cpp index f0f7b67cf801c..b7c0fa4f4bf23 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17009,7 +17009,7 @@ static int llama_decode_internal( } break; case LLAMA_POOLING_TYPE_RANK: { - // extract the rank score - a single float per sequence + // extract the rerank score - a single float per sequence auto & embd_seq_out = lctx.embd_seq; for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { @@ -17211,7 +17211,6 @@ static int llama_encode_internal( case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_LAST: - case LLAMA_POOLING_TYPE_RANK: { // extract sequence embeddings auto & embd_seq_out = lctx.embd_seq; @@ -17228,6 +17227,13 @@ static int llama_encode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); } } break; + case LLAMA_POOLING_TYPE_RANK: + { + // TODO: this likely should be the same logic as in llama_decoder_internal, but better to + // wait for an encoder model that requires this pooling type in order to test it + // https://github.com/ggerganov/llama.cpp/pull/9510 + GGML_ABORT("RANK pooling not implemented yet"); + } case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type"); From 7bde9a04525d4d0ab5b053649dfb5620daa956a2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Sep 2024 17:12:32 +0300 Subject: [PATCH 10/24] server : accept /rerank endpoint in addition to /v1/rerank [no ci] --- examples/server/server.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ce65164d1f2ac..084dea212cbd7 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3296,6 +3296,7 @@ int main(int argc, char ** argv) { svr->Post("/embedding", handle_embeddings); // legacy svr->Post("/embeddings", handle_embeddings); svr->Post("/v1/embeddings", handle_embeddings); + svr->Post("/rerank", handle_rerank); svr->Post("/v1/rerank", handle_rerank); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); From c62a39d91eb6af72536300a304cd99cdcc75b7b6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Sep 2024 20:36:38 +0300 Subject: [PATCH 11/24] embedding : parse special tokens --- examples/embedding/embedding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 18d6512608901..36e4f2e4de9fd 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -135,7 +135,7 @@ int main(int argc, char ** argv) { // tokenize the prompts and trim std::vector> inputs; for (const auto & prompt : prompts) { - auto inp = ::llama_tokenize(ctx, prompt, true, false); + auto inp = ::llama_tokenize(ctx, prompt, true, true); if (inp.size() > n_batch) { LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n", __func__, (long long int) inp.size(), (long long int) n_batch); From 866c0113fbd070a42839f09042fed5e951af1b33 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Sep 2024 20:39:25 +0300 Subject: [PATCH 12/24] jina : support v1 reranker --- convert_hf_to_gguf.py | 10 ++++++++++ convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 1 + gguf-py/gguf/tensor_mapping.py | 1 + src/llama.cpp | 34 ++++++++++++++++++++++++++-------- 5 files changed, 39 insertions(+), 8 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d69a0d9f8270c..c7885f85e951d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -597,6 +597,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e": # ref: https://huggingface.co/databricks/dbrx-base res = "dbrx" + if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448": + # ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + res = "jina-v1-en" if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en res = "jina-v2-en" @@ -3117,6 +3120,13 @@ def set_vocab(self): self.gguf_writer.add_add_bos_token(True) self.gguf_writer.add_add_eos_token(True) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "bert.", remove the prefix + # e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + if name.startswith("bert."): + name = name[5:] + + return super().modify_tensors(data_torch, name, bid) @Model.register("OpenELMForCausalLM") class OpenELMModel(Model): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 021f65abdc45d..527bc44e51036 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -81,6 +81,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, + {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", }, {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 414a3ae211dd4..f7e1290e626ca 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -647,6 +647,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 4a34a549de18b..48d359071b7a5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -681,6 +681,7 @@ class TensorNameMap: ), MODEL_TENSOR.CLS: ( + "classifier", # jina "classifier.dense", # roberta ), diff --git a/src/llama.cpp b/src/llama.cpp index b7c0fa4f4bf23..39f592cd12111 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -828,6 +828,7 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, }, }, { @@ -5590,11 +5591,11 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { - case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small + case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base } } break; @@ -6287,6 +6288,7 @@ static void llm_load_vocab( tokenizer_pre == "phi-2" || tokenizer_pre == "jina-es" || tokenizer_pre == "jina-de" || + tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "jina-v2-code") { @@ -6408,7 +6410,12 @@ static void llm_load_vocab( for (uint32_t i = 0; i < n_vocab; i++) { std::string word = gguf_get_arr_str(ctx, token_idx, i); - GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + if (word.empty()) { + LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); + word = "[EMPTY_" + std::to_string(i) + "]"; + } vocab.token_to_id[word] = i; vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); @@ -6487,8 +6494,14 @@ static void llm_load_vocab( vocab.linefeed_id = ids[0]; } else { const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A - GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); - vocab.linefeed_id = ids[0]; + + //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + if (ids.empty()) { + LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__); + vocab.linefeed_id = vocab.special_pad_id; + } else { + vocab.linefeed_id = ids[0]; + } } // special tokens @@ -7419,6 +7432,8 @@ static bool llm_load_tensors( model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias + model.cls = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_split = ctx_for_layer_split(i); @@ -10237,12 +10252,15 @@ struct llm_build_context { // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 GGML_ASSERT(model.cls != nullptr); GGML_ASSERT(model.cls_b != nullptr); - GGML_ASSERT(model.cls_out != nullptr); - GGML_ASSERT(model.cls_out_b != nullptr); cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b); cur = ggml_tanh(ctx0, cur); - cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); + + if (model.cls_out) { + GGML_ASSERT(model.cls_out_b != nullptr); + + cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); + } } break; default: { From 84f56f3c45da31371bfe9674d6d23821aa3bbc49 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Sep 2024 20:39:37 +0300 Subject: [PATCH 13/24] vocab : minor style ggml-ci --- src/llama-vocab.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a771eccda3017..5be1ec5e130ed 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1477,7 +1477,7 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, { llm_tokenizer_ugm tokenizer(vocab); - if (add_special && vocab.tokenizer_add_bos != 0) { + if (add_special && vocab.tokenizer_add_bos) { GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); } @@ -1494,14 +1494,14 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, } } - if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) { + if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { LLAMA_LOG_WARN( "%s: Added a BOS token to the prompt as specified by the model but the prompt " "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (add_special && vocab.tokenizer_add_eos == 1) { + if (add_special && vocab.tokenizer_add_eos) { GGML_ASSERT(vocab.special_eos_id != -1); output.push_back(vocab.special_eos_id); } @@ -1713,11 +1713,13 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { return _try_copy(token_text.data(), token_text.size()); - } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + } + if (attr & LLAMA_TOKEN_ATTR_NORMAL) { std::string result = token_text; llama_unescape_whitespace(result); return _try_copy(result.data(), result.size()); - } else if (attr & LLAMA_TOKEN_ATTR_BYTE) { + } + if (attr & LLAMA_TOKEN_ATTR_BYTE) { char byte = (char) llama_token_to_byte(vocab, token); return _try_copy((char*) &byte, 1); } @@ -1728,7 +1730,8 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token // suppressing them like CONTROL tokens. if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { return _try_copy(token_text.data(), token_text.size()); - } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + } + if (attr & LLAMA_TOKEN_ATTR_NORMAL) { std::string result = llama_decode_text(token_text); return _try_copy(result.data(), result.size()); } From 00b33760aa141e05379f5f5541bd1205e3bf52ee Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 26 Sep 2024 13:17:22 +0300 Subject: [PATCH 14/24] server : initiate tests for later ggml-ci --- convert_hf_to_gguf.py | 1 + examples/server/tests/features/rerank.feature | 39 +++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 examples/server/tests/features/rerank.feature diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c7885f85e951d..9be6e17e244bc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3128,6 +3128,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) + @Model.register("OpenELMForCausalLM") class OpenELMModel(Model): model_arch = gguf.MODEL_ARCH.OPENELM diff --git a/examples/server/tests/features/rerank.feature b/examples/server/tests/features/rerank.feature new file mode 100644 index 0000000000000..e6981673d2d2c --- /dev/null +++ b/examples/server/tests/features/rerank.feature @@ -0,0 +1,39 @@ +@llama.cpp +@rerank +Feature: llama.cpp server + + Background: Server startup + Given a server listening on localhost:8080 + And a model url https://huggingface.co/ggml-org/models/resolve/main/jina-reranker-v1-tiny-en/ggml-model-f16.gguf + And a model file jina-reranker-v1-tiny-en.gguf + And a model alias jina-reranker-v1-tiny-en + And 42 as server seed + And 2 slots + And 128 as batch size + And 128 as ubatch size + And 512 KV cache size + And embeddings extraction + Then the server is starting + Then the server is healthy + +# TODO: implement some tests +# https://github.com/ggerganov/llama.cpp/pull/9510 +# Scenario: Rerank +# Given a prompt: +# """ +# What is panda? +# """ +# And a prompt: +# """ +# Hi. +# """ +# And a prompt: +# """ +# It's a bear. +# """ +# And a prompt: +# """ +# The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China. +# """ +# When reranking request +# Then reranking results are returned From 877a04ccff10f5987b5fb0cc453090c08875fc5a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 26 Sep 2024 14:31:03 +0300 Subject: [PATCH 15/24] server : add docs --- common/arg.cpp | 2 +- examples/server/README.md | 36 +++++++++++++++++++++++++++++++++++- examples/server/server.cpp | 2 ++ 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 12f05cc20cb4c..de8e6bac8ca59 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1093,7 +1093,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, } ).set_sparam()); add_opt(llama_arg( - {"--pooling"}, "{none,mean,cls,last, rank}", + {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", [](gpt_params & params, const std::string & value) { /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } diff --git a/examples/server/README.md b/examples/server/README.md index dfca07f988824..2562680cb6c98 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. **Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes + * Reranking endoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching * Multimodal (wip) @@ -130,7 +131,7 @@ The project is under active development, and we are [looking for feedback and co | `--no-context-shift` | disables context shift on inifinite text generation (default: disabled)
(env: LLAMA_ARG_NO_CONTEXT_SHIFT) | | `-sp, --special` | special tokens output enabled (default: false) | | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | -| `--pooling {none,mean,cls,last}` | pooling type for embeddings, use model default if unspecified
(env: LLAMA_ARG_POOLING) | +| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified
(env: LLAMA_ARG_POOLING) | | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)
(env: LLAMA_ARG_CONT_BATCHING) | | `-nocb, --no-cont-batching` | disable continuous batching
(env: LLAMA_ARG_NO_CONT_BATCHING) | | `-a, --alias STRING` | set alias for model name (to be used by REST API)
(env: LLAMA_ARG_ALIAS) | @@ -478,6 +479,39 @@ The same as [the embedding example](../embedding) does. `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. +### POST `/reranking`: Rerank documents according to a given query + +Similar to https://jina.ai/reranker/ but might change in the future. +Requires a reranker model (such as [bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)) and the `--embedding --pooling rank` options. + + *Options:* + + `query`: The query against which the documents will be ranked. + + `documents`: An array strings representing the documents to be ranked. + + *Aliases:* + - `/rerank` + - `/v1/rerank` + - `/v1/reranking` + + *Examples:* + + ```shell + curl http://127.0.0.1:8012/v1/rerank \ + -H "Content-Type: application/json" \ + -d '{ + "model": "some-model", + "query": "What is panda?", + "top_n": 3, + "documents": [ + "hi", + "it is a bear", + "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." + ] + }' | jq + ``` + ### POST `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as stream. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 084dea212cbd7..726d4a7e37fa5 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3297,7 +3297,9 @@ int main(int argc, char ** argv) { svr->Post("/embeddings", handle_embeddings); svr->Post("/v1/embeddings", handle_embeddings); svr->Post("/rerank", handle_rerank); + svr->Post("/reranking", handle_rerank); svr->Post("/v1/rerank", handle_rerank); + svr->Post("/v1/reranking", handle_rerank); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); // LoRA adapters hotswap From 4d457755c08c4a59a3436959888da81399a25e5d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 26 Sep 2024 14:36:14 +0300 Subject: [PATCH 16/24] llama : add comment [no ci] --- src/llama.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index 39f592cd12111..c85e3033c125e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10257,6 +10257,8 @@ struct llm_build_context { cur = ggml_tanh(ctx0, cur); if (model.cls_out) { + // this path is taken for example by the https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 GGML_ASSERT(model.cls_out_b != nullptr); cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b); From ca99a6ce70bf4fcb50a70b7efa9d6dbd1a93f22e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 26 Sep 2024 15:20:11 +0300 Subject: [PATCH 17/24] llama : fix uninitialized tensors --- src/llama.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index c85e3033c125e..179c0f977e4a6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2875,6 +2875,7 @@ struct llama_model { llama_hparams hparams = {}; llama_vocab vocab; + // TODO: should init all tensors to nullptr struct ggml_tensor * tok_embd; struct ggml_tensor * type_embd; struct ggml_tensor * pos_embd; @@ -2890,8 +2891,8 @@ struct llama_model { // classifier struct ggml_tensor * cls; struct ggml_tensor * cls_b; - struct ggml_tensor * cls_out; - struct ggml_tensor * cls_out_b; + struct ggml_tensor * cls_out = nullptr; + struct ggml_tensor * cls_out_b = nullptr; std::vector layers; From f19554f45390b066dee502cb1301f591b06aaf70 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 26 Sep 2024 15:20:32 +0300 Subject: [PATCH 18/24] ci : add rerank tests ggml-ci --- ci/run.sh | 85 ++++++++++++++++++++++++++++---- examples/embedding/embedding.cpp | 1 + 2 files changed, 77 insertions(+), 9 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index 1ac08ee4e19a8..7d241ecc0ea06 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -712,6 +712,81 @@ function gg_run_embd_bge_small { set +e } +function gg_sum_embd_bge_small { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'BGE Small (BERT):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" + gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" +} + +# rerank_tiny + +function gg_run_rerank_tiny { + cd ${SRC} + + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + + gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json + + path_models="../models-mnt/rerank-tiny" + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf + + model_f16="${path_models}/ggml-model-f16.gguf" + + (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?hi\nwhat is panda?it's a bear\nwhat is panda?The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log + + # sample output + # rerank score 0: 0.029 + # rerank score 1: 0.029 + # rerank score 2: 0.135 + + # check that the score is in the range [$3, $4] + function check_score { + qnt="$1" + score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) + + if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then + printf ' - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4" + return 20 + fi + + printf ' - %s @ %s OK\n' "$qnt" "$score" + return 0 + } + + check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.15" | tee -a $OUT/${ci}-rk-f16.log + + set +e +} + +function gg_sum_rerank_tiny { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Rerank Tiny (Jina):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)" +} + function gg_check_build_requirements { if ! command -v cmake &> /dev/null; then gg_printf 'cmake not found, please install' @@ -726,15 +801,6 @@ function gg_check_build_requirements { fi } -function gg_sum_embd_bge_small { - gg_printf '### %s\n\n' "${ci}" - - gg_printf 'BGE Small (BERT):\n' - gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" - gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" - gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" -} - ## main export LLAMA_LOG_PREFIX=1 @@ -762,6 +828,7 @@ test $ret -eq 0 && gg_run ctest_release if [ -z ${GG_BUILD_LOW_PERF} ]; then test $ret -eq 0 && gg_run embd_bge_small + test $ret -eq 0 && gg_run rerank_tiny if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then test $ret -eq 0 && gg_run test_scripts_debug diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 36e4f2e4de9fd..7349268223827 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -236,6 +236,7 @@ int main(int argc, char ** argv) { } } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { for (int j = 0; j < n_embd_count; j++) { + // NOTE: if you change this log - update the tests in ci/run.sh LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); } } else { From f27dd6990dff42676f4443c4f0180c94f136d408 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 26 Sep 2024 17:43:02 +0200 Subject: [PATCH 19/24] add reranking test --- examples/server/tests/features/rerank.feature | 49 ++++++++++--------- examples/server/tests/features/steps/steps.py | 40 +++++++++++++++ 2 files changed, 66 insertions(+), 23 deletions(-) diff --git a/examples/server/tests/features/rerank.feature b/examples/server/tests/features/rerank.feature index e6981673d2d2c..72616d6c5d4c4 100644 --- a/examples/server/tests/features/rerank.feature +++ b/examples/server/tests/features/rerank.feature @@ -9,31 +9,34 @@ Feature: llama.cpp server And a model alias jina-reranker-v1-tiny-en And 42 as server seed And 2 slots - And 128 as batch size - And 128 as ubatch size + And 512 as batch size + And 512 as ubatch size And 512 KV cache size And embeddings extraction Then the server is starting Then the server is healthy -# TODO: implement some tests -# https://github.com/ggerganov/llama.cpp/pull/9510 -# Scenario: Rerank -# Given a prompt: -# """ -# What is panda? -# """ -# And a prompt: -# """ -# Hi. -# """ -# And a prompt: -# """ -# It's a bear. -# """ -# And a prompt: -# """ -# The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China. -# """ -# When reranking request -# Then reranking results are returned + Scenario: Rerank + Given a rerank query: + """ + Organic skincare products for sensitive skin + """ + And a rerank document: + """ + Organic skincare for sensitive skin with aloe vera and chamomile: Imagine the soothing embrace of nature with our organic skincare range, crafted specifically for sensitive skin. Infused with the calming properties of aloe vera and chamomile, each product provides gentle nourishment and protection. Say goodbye to irritation and hello to a glowing, healthy complexion. + """ + And a rerank document: + """ + New makeup trends focus on bold colors and innovative techniques: Step into the world of cutting-edge beauty with this seasons makeup trends. Bold, vibrant colors and groundbreaking techniques are redefining the art of makeup. From neon eyeliners to holographic highlighters, unleash your creativity and make a statement with every look. + """ + And a rerank document: + """ + Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras: Entra en el fascinante mundo del maquillaje con las tendencias más actuales. Colores vivos y técnicas innovadoras están revolucionando el arte del maquillaje. Desde delineadores neón hasta iluminadores holográficos, desata tu creatividad y destaca en cada look. + """ + And a rerank document: + """ + 新的化妆趋势注重鲜艳的颜色和创新的技巧:进入化妆艺术的新纪元,本季的化妆趋势以大胆的颜色和创新的技巧为主。无论是霓虹眼线还是全息高光,每一款妆容都能让您脱颖而出,展现独特魅力。 + """ + When reranking request + Then reranking results are returned + Then reranking highest score is index 2 diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 0fea0fe87b799..9ae2ce67bafc9 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -83,6 +83,10 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.concurrent_tasks = [] context.prompts = [] + context.reranking_query = None + context.reranking_documents = [] + context.reranking_results = None + @step('a model file {hf_file} from HF repo {hf_repo}') def step_download_hf_model(context, hf_file: str, hf_repo: str): @@ -452,6 +456,14 @@ def step_impl(context, n_ga_w): def step_prompt_passkey(context): context.prompt_passkey = context_text(context) +@step('a rerank query') +def step_set_rerank_query(context): + context.reranking_query = context_text(context) + context.reranking_documents = [] + +@step('a rerank document') +def step_set_rerank_document(context): + context.reranking_documents.append(context_text(context)) @step('{n_prompts:d} fixed prompts') def step_fixed_prompts(context, n_prompts): @@ -619,6 +631,22 @@ async def step_compute_embedding(context): context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url) +@step('reranking request') +@async_run_until_complete +async def step_compute_reranking(context): + async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: + async with session.post(f'{context.base_url}/reranking', + json={ + "query": context.reranking_query, + "documents": context.reranking_documents, + }) as response: + if response.status == 200: + response_json = await response.json() + context.reranking_results = response_json['results'] + else: + context.reranking_results = response.status + + @step('all embeddings are the same') @async_run_until_complete async def step_all_embeddings_are_the_same(context): @@ -704,6 +732,18 @@ async def all_embeddings_are_generated(context): for i in range(n_embedding_requests): assert_embeddings(context.tasks_result.pop().pop()) +@step('reranking results are returned') +def reranking_results_are_returned(context): + assert len(context.reranking_results) == len(context.reranking_documents) + +@step('reranking highest score is index {idx:d}') +def reranking_results_are_returned(context, idx: int): + max_score, max_idx = 0, 0 + for res in context.reranking_results: + if max_score < res['relevance_score']: + max_score = res['relevance_score'] + max_idx = res['index'] + assert max_idx == idx @step('adding special tokens') def step_tokenize_set_add_special(context): From 1ae8376d401d3cb1b008242a2183100e80a997cd Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 26 Sep 2024 17:52:57 +0200 Subject: [PATCH 20/24] change test data --- examples/server/tests/features/rerank.feature | 13 +++++++------ examples/server/tests/features/steps/steps.py | 12 +++++++++--- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/examples/server/tests/features/rerank.feature b/examples/server/tests/features/rerank.feature index 72616d6c5d4c4..79e00f689e990 100644 --- a/examples/server/tests/features/rerank.feature +++ b/examples/server/tests/features/rerank.feature @@ -19,24 +19,25 @@ Feature: llama.cpp server Scenario: Rerank Given a rerank query: """ - Organic skincare products for sensitive skin + Machine learning is """ And a rerank document: """ - Organic skincare for sensitive skin with aloe vera and chamomile: Imagine the soothing embrace of nature with our organic skincare range, crafted specifically for sensitive skin. Infused with the calming properties of aloe vera and chamomile, each product provides gentle nourishment and protection. Say goodbye to irritation and hello to a glowing, healthy complexion. + A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines. """ And a rerank document: """ - New makeup trends focus on bold colors and innovative techniques: Step into the world of cutting-edge beauty with this seasons makeup trends. Bold, vibrant colors and groundbreaking techniques are redefining the art of makeup. From neon eyeliners to holographic highlighters, unleash your creativity and make a statement with every look. + Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants. """ And a rerank document: """ - Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras: Entra en el fascinante mundo del maquillaje con las tendencias más actuales. Colores vivos y técnicas innovadoras están revolucionando el arte del maquillaje. Desde delineadores neón hasta iluminadores holográficos, desata tu creatividad y destaca en cada look. + Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions. """ And a rerank document: """ - 新的化妆趋势注重鲜艳的颜色和创新的技巧:进入化妆艺术的新纪元,本季的化妆趋势以大胆的颜色和创新的技巧为主。无论是霓虹眼线还是全息高光,每一款妆容都能让您脱颖而出,展现独特魅力。 + Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine. """ When reranking request Then reranking results are returned - Then reranking highest score is index 2 + # TODO: this result make no sense, probably need a better model? + Then reranking highest score is index 3 and lowest score is index 0 diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 9ae2ce67bafc9..1fe101673f306 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -736,14 +736,20 @@ async def all_embeddings_are_generated(context): def reranking_results_are_returned(context): assert len(context.reranking_results) == len(context.reranking_documents) -@step('reranking highest score is index {idx:d}') -def reranking_results_are_returned(context, idx: int): +@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}') +def reranking_results_are_returned(context, idx_high: int, idx_low: int): max_score, max_idx = 0, 0 + min_score, min_idx = 0, 0 for res in context.reranking_results: if max_score < res['relevance_score']: max_score = res['relevance_score'] max_idx = res['index'] - assert max_idx == idx + if min_score > res['relevance_score']: + min_score = res['relevance_score'] + min_idx = res['index'] + print(context.reranking_results) + assert max_idx == idx_high + assert min_idx == idx_low @step('adding special tokens') def step_tokenize_set_add_special(context): From 84b0af8355e34b378e60fbf30b5a22be7ae65ce3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 27 Sep 2024 10:46:37 +0300 Subject: [PATCH 21/24] Update examples/server/server.cpp Co-authored-by: Xuan Son Nguyen --- examples/server/server.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 726d4a7e37fa5..613d55ccbdf7d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3144,7 +3144,6 @@ int main(int argc, char ** argv) { return; } } else { - exit(0); res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } From 0d6f6a799f3bc85ccef8a9f036670a6d3456c2b6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 27 Sep 2024 15:25:39 +0200 Subject: [PATCH 22/24] add `--reranking` argument --- common/arg.cpp | 11 +++++++ common/common.cpp | 5 ++++ common/common.h | 1 + examples/server/server.cpp | 29 ++++++++++--------- .../server/tests/features/embeddings.feature | 2 +- examples/server/tests/features/rerank.feature | 5 ++-- examples/server/tests/features/steps/steps.py | 8 ++++- 7 files changed, 43 insertions(+), 18 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index de8e6bac8ca59..8266a16c261c5 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx params.kv_overrides.back().key[0] = 0; } + if (params.reranking && params.embedding) { + throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); + } + return true; } @@ -1750,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, params.embedding = true; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); + add_opt(llama_arg( + {"--reranking", "--rerank"}, + format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"), + [](gpt_params & params) { + params.reranking = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); add_opt(llama_arg( {"--api-key"}, "KEY", "API key to use for authentication (default: none)", diff --git a/common/common.cpp b/common/common.cpp index 8d0ed4f95a737..e2b8574bf77d7 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; + if (params.reranking) { + cparams.embeddings = true; + cparams.pooling_type = LLAMA_POOLING_TYPE_RANK; + } + cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); diff --git a/common/common.h b/common/common.h index cb87c4479ed0a..8b84cf9ad45ee 100644 --- a/common/common.h +++ b/common/common.h @@ -271,6 +271,7 @@ struct gpt_params { int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix std::string embd_sep = "\n"; // separator of embendings + bool reranking = false; // enable reranking support on server // server params int32_t port = 8080; // server listens on this network port diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 613d55ccbdf7d..f343cc252f89a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2888,8 +2888,8 @@ int main(int argc, char ** argv) { }; const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { - if (ctx_server.params.embedding) { - res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + if (ctx_server.params.embedding || ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2949,8 +2949,8 @@ int main(int argc, char ** argv) { // TODO: maybe merge this function with "handle_completions_generic" const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) { - if (ctx_server.params.embedding) { - res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + if (ctx_server.params.embedding || ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -3074,6 +3074,11 @@ int main(int argc, char ** argv) { }; const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + // TODO: somehow clean up this checks in the future + if (!ctx_server.params.embedding || ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } const json body = json::parse(req.body); bool is_openai = false; @@ -3125,6 +3130,10 @@ int main(int argc, char ** argv) { }; const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + if (!ctx_server.params.reranking) { + res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } const json body = json::parse(req.body); // TODO: implement @@ -3148,15 +3157,9 @@ int main(int argc, char ** argv) { return; } - json documents; - if (body.count("documents") != 0) { - documents = body.at("documents"); - if (!documents.is_array() || documents.size() == 0) { - res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); - return; - } - } else { - res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + std::vector documents = json_value(body, "documents", std::vector()); + if (documents.empty()) { + res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); return; } diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature index 818ea3beb90cd..f4fe2ee4335ff 100644 --- a/examples/server/tests/features/embeddings.feature +++ b/examples/server/tests/features/embeddings.feature @@ -15,7 +15,7 @@ Feature: llama.cpp server And 128 as batch size And 128 as ubatch size And 512 KV cache size - And embeddings extraction + And enable embeddings endpoint Then the server is starting Then the server is healthy diff --git a/examples/server/tests/features/rerank.feature b/examples/server/tests/features/rerank.feature index 79e00f689e990..c36cc8e215fa6 100644 --- a/examples/server/tests/features/rerank.feature +++ b/examples/server/tests/features/rerank.feature @@ -12,7 +12,7 @@ Feature: llama.cpp server And 512 as batch size And 512 as ubatch size And 512 KV cache size - And embeddings extraction + And enable reranking endpoint Then the server is starting Then the server is healthy @@ -39,5 +39,4 @@ Feature: llama.cpp server """ When reranking request Then reranking results are returned - # TODO: this result make no sense, probably need a better model? - Then reranking highest score is index 3 and lowest score is index 0 + Then reranking highest score is index 2 and lowest score is index 3 diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 1fe101673f306..2611614ba3633 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.server_api_key = None context.server_continuous_batching = False context.server_embeddings = False + context.server_reranking = False context.server_metrics = False context.server_process = None context.seed = None @@ -176,10 +177,13 @@ def step_server_continuous_batching(context): context.server_continuous_batching = True -@step('embeddings extraction') +@step('enable embeddings endpoint') def step_server_embeddings(context): context.server_embeddings = True +@step('enable reranking endpoint') +def step_server_reranking(context): + context.server_reranking = True @step('prometheus compatible metrics exposed') def step_server_metrics(context): @@ -1408,6 +1412,8 @@ def start_server_background(context): server_args.append('--cont-batching') if context.server_embeddings: server_args.append('--embedding') + if context.server_reranking: + server_args.append('--reranking') if context.server_metrics: server_args.append('--metrics') if context.model_alias: From a4ac45f65991859ea7359931476fdc4a434cb7bb Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 27 Sep 2024 15:30:41 +0200 Subject: [PATCH 23/24] update server docs --- examples/server/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/server/README.md b/examples/server/README.md index 2562680cb6c98..951c4a44c6058 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -24,6 +24,7 @@ The project is under active development, and we are [looking for feedback and co | -------- | ----------- | | `-h, --help, --usage` | print usage and exit | | `--version` | show version and build info | +| `--verbose-prompt` | print a verbose prompt before generation (default: false) | | `-t, --threads N` | number of threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | @@ -139,6 +140,7 @@ The project is under active development, and we are [looking for feedback and co | `--port PORT` | port to listen (default: 8080)
(env: LLAMA_ARG_PORT) | | `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | +| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | | `--api-key KEY` | API key to use for authentication (default: none)
(env: LLAMA_API_KEY) | | `--api-key-file FNAME` | path to file containing API keys (default: none) | | `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key
(env: LLAMA_ARG_SSL_KEY_FILE) | @@ -153,6 +155,7 @@ The project is under active development, and we are [looking for feedback and co | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | + Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var. Example usage of docker compose with environment variables: From 39167b69c02dfbc83ed512e5393d1bc7c4fd842e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 28 Sep 2024 14:51:57 +0300 Subject: [PATCH 24/24] llama : fix comment [no ci] ggml-ci --- src/llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 179c0f977e4a6..a592d890a2808 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10257,9 +10257,9 @@ struct llm_build_context { cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b); cur = ggml_tanh(ctx0, cur); + // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 if (model.cls_out) { - // this path is taken for example by the https://huggingface.co/jinaai/jina-reranker-v1-tiny-en - // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 GGML_ASSERT(model.cls_out_b != nullptr); cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);