 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -632,6 +633,49 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
+llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) {
+    auto * gf = graph_init();
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_FAILED;
+        }
+        return nullptr;
+    }
+
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_FAILED;
+        }
+        return nullptr;
+    }
+
+    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_ALLOC_FAILED;
+        }
+        return nullptr;
+    }
+
+    res->set_inputs(&ubatch);
+
+    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+        if (ret) {
+            *ret = status;
+        }
+        return nullptr;
+    }
+
+    return res;
+}
+
 int llama_context::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -703,26 +747,18 @@ int llama_context::encode(llama_batch & inp_batch) {
     // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
     cparams.causal_attn = false;
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
-
-    ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-    res->set_inputs(&ubatch);
+    ggml_status status;
+    auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status);
 
     cparams.causal_attn = causal_attn_org;
 
-    const auto compute_status = graph_compute(gf, n_tokens > 1);
-    switch (compute_status) {
-        case GGML_STATUS_SUCCESS:
-            break;
-        case GGML_STATUS_ABORTED:
-            return 2;
-        case GGML_STATUS_ALLOC_FAILED:
-            return -2;
-        case GGML_STATUS_FAILED:
-        default:
-            return -3;
+    if (!res) {
+        switch (status) {
+            case GGML_STATUS_ABORTED:      return 2;
+            case GGML_STATUS_ALLOC_FAILED: return -2;
+            case GGML_STATUS_FAILED:       return -3;
+            case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
+        }
     }
 
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -942,25 +978,34 @@ int llama_context::decode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
+    ggml_status status;
+    auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);
+
+    if (!res) {
+        // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+        llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+
+        for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+            const auto & seq_id = ubatch.seq_id[i][0];
 
-    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+            pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+        }
 
-    ggml_backend_sched_alloc_graph(sched.get(), gf);
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                continue;
+            }
 
-    res->set_inputs(&ubatch);
+            LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+
+            llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+        }
 
-    const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
-    if (compute_status != GGML_STATUS_SUCCESS) {
-        switch (compute_status) {
-            case GGML_STATUS_ABORTED:
-                return 2;
-            case GGML_STATUS_ALLOC_FAILED:
-                return -2;
-            case GGML_STATUS_FAILED:
-            default:
-                return -3;
+        switch (status) {
+            case GGML_STATUS_ABORTED:      return 2;
+            case GGML_STATUS_ALLOC_FAILED: return -2;
+            case GGML_STATUS_FAILED:       return -3;
+            case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
         }
     }
 
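For context, a minimal caller-side sketch (not part of this commit) of the retry pattern that the KV-cache rollback above makes safe: when a decode fails or is aborted, the positions of the failed ubatch are removed from the cache, so the batch can simply be resubmitted. This assumes the public llama_decode() API from llama.h, that the batch fits in a single ubatch (so no partially processed prefix is left behind), and that a return value of 2 maps to GGML_STATUS_ABORTED as in the switch above; the helper decode_with_retry is hypothetical.

#include "llama.h"

// Hypothetical helper: retry a decode that was aborted (e.g. via the abort callback).
// Assumes: 0 == success, 2 == aborted, anything else is treated as fatal here.
static bool decode_with_retry(llama_context * ctx, llama_batch batch, int max_tries) {
    for (int i = 0; i < max_tries; ++i) {
        const int32_t ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return true;  // success
        }
        if (ret == 2) {
            // aborted: the failed ubatch's positions were rolled back from the
            // KV cache by llama_context::decode(), so retrying is safe
            continue;
        }
        return false;     // allocation/compute failure or other error
    }
    return false;
}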