Commit a3ebf0a

llama : handle aborts and compute errors
ggml-ci
1 parent 2252eef · commit a3ebf0a

3 files changed: +89 -36


include/llama.h

Lines changed: 2 additions & 0 deletions
@@ -677,12 +677,14 @@ extern "C" {
 
     // Returns the smallest position present in the KV cache for the specified sequence
     // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
             struct llama_context * ctx,
                    llama_seq_id seq_id);
 
     // Returns the largest position present in the KV cache for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
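
The added comments make the contiguity of [pos_min, pos_max] explicit. A minimal caller-side sketch of what that guarantee allows; the helper below is hypothetical and not part of this commit, only the two llama_kv_self_seq_pos_* calls come from the header:

    #include "llama.h"

    // Hypothetical helper: with the contiguity guarantee, the number of cached
    // positions for a sequence follows directly from the two bounds.
    static int32_t kv_cached_count(struct llama_context * ctx, llama_seq_id seq_id) {
        const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id);
        const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id);
        if (p_min < 0 || p_max < 0) {
            return 0; // the sequence is empty
        }
        // every position in [p_min, p_max] is guaranteed to be present
        return (int32_t) (p_max - p_min + 1);
    }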

src/llama-context.cpp

Lines changed: 78 additions & 33 deletions
@@ -6,9 +6,10 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -632,6 +633,49 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
+llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) {
+    auto * gf = graph_init();
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_FAILED;
+        }
+        return nullptr;
+    }
+
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
+    if (!res) {
+        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_FAILED;
+        }
+        return nullptr;
+    }
+
+    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+        if (ret) {
+            *ret = GGML_STATUS_ALLOC_FAILED;
+        }
+        return nullptr;
+    }
+
+    res->set_inputs(&ubatch);
+
+    const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+        if (ret) {
+            *ret = status;
+        }
+        return nullptr;
+    }
+
+    return res;
+}
+
 int llama_context::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -703,26 +747,18 @@ int llama_context::encode(llama_batch & inp_batch) {
     // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
     cparams.causal_attn = false;
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
-
-    ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-    res->set_inputs(&ubatch);
+    ggml_status status;
+    auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status);
 
     cparams.causal_attn = causal_attn_org;
 
-    const auto compute_status = graph_compute(gf, n_tokens > 1);
-    switch (compute_status) {
-        case GGML_STATUS_SUCCESS:
-            break;
-        case GGML_STATUS_ABORTED:
-            return 2;
-        case GGML_STATUS_ALLOC_FAILED:
-            return -2;
-        case GGML_STATUS_FAILED:
-        default:
-            return -3;
+    if (!res) {
+        switch (status) {
+            case GGML_STATUS_ABORTED:      return 2;
+            case GGML_STATUS_ALLOC_FAILED: return -2;
+            case GGML_STATUS_FAILED:       return -3;
+            case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
+        }
     }
 
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -942,25 +978,34 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-        auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
+        ggml_status status;
+        auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);
+
+        if (!res) {
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                const auto & seq_id = ubatch.seq_id[i][0];
 
-        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+            }
 
-        ggml_backend_sched_alloc_graph(sched.get(), gf);
+            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                    continue;
+                }
 
-        res->set_inputs(&ubatch);
+                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+
+                llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+            }
 
-        const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
-        if (compute_status != GGML_STATUS_SUCCESS) {
-            switch (compute_status) {
-                case GGML_STATUS_ABORTED:
-                    return 2;
-                case GGML_STATUS_ALLOC_FAILED:
-                    return -2;
-                case GGML_STATUS_FAILED:
-                default:
-                    return -3;
+            switch (status) {
+                case GGML_STATUS_ABORTED:      return 2;
+                case GGML_STATUS_ALLOC_FAILED: return -2;
+                case GGML_STATUS_FAILED:       return -3;
+                case GGML_STATUS_SUCCESS:      GGML_ABORT("should not happen");
             }
         }
 
src/llama-context.h

Lines changed: 9 additions & 3 deletions
@@ -89,6 +89,14 @@ struct llama_context {
             int32_t il_start,
             int32_t il_end);
 
+    // process a single ubatch with a specific graph type
+    // ret contains the status of the graph computation
+    // returns nullptr only if ret != GGML_STATUS_SUCCESS
+    llm_graph_result_ptr process(
+            const llama_ubatch & ubatch,
+                  llm_graph_type   gtype,
+                  ggml_status    * ret);
+
     int encode(llama_batch & inp_batch);
     int decode(llama_batch & inp_batch);

@@ -181,9 +189,7 @@ struct llama_context {
     ggml_cgraph * graph_init();
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
-    ggml_status graph_compute(
-            ggml_cgraph * gf,
-                   bool   batched);
+    ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs);
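
The GGML_STATUS_ABORTED path handled above is driven by the existing abort-callback mechanism. A small sketch of wiring it up from the application side; the deadline policy is an illustrative assumption, only llama_set_abort_callback() and the bool(void *) callback shape are taken from the existing public API:

    #include "llama.h"

    #include <chrono>

    // Hypothetical abort policy: stop the current graph compute once a deadline
    // passes. Returning true from the callback makes the compute report
    // GGML_STATUS_ABORTED, which decode()/encode() translate into a return value of 2.
    struct deadline {
        std::chrono::steady_clock::time_point t_end;
    };

    static bool abort_if_past_deadline(void * data) {
        const auto * d = static_cast<const deadline *>(data);
        return std::chrono::steady_clock::now() > d->t_end;
    }

    // usage sketch:
    //   deadline d { std::chrono::steady_clock::now() + std::chrono::seconds(5) };
    //   llama_set_abort_callback(ctx, abort_if_past_deadline, &d);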
