
Commit 7dc61c2

llama : handle aborts and compute errors
ggml-ci
1 parent 2252eef

2 files changed, +36 -4 lines changed

include/llama.h

Lines changed: 2 additions & 0 deletions
@@ -677,12 +677,14 @@ extern "C" {
 
     // Returns the smallest position present in the KV cache for the specified sequence
    // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
             struct llama_context * ctx,
                  llama_seq_id seq_id);
 
     // Returns the largest position present in the KV cache for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
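
For context, these two accessors together describe a contiguous range of cached positions. A minimal usage sketch under the updated contract (print_kv_range is an illustrative helper, not part of the API; ctx is assumed to be a valid context):

#include "llama.h"

#include <cstdio>

// Illustrative helper: report the contiguous range of positions held in
// the KV cache for `seq_id`, using the accessors documented above.
static void print_kv_range(struct llama_context * ctx, llama_seq_id seq_id) {
    const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id);
    const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id);

    if (p_min == -1) {
        // both accessors return -1 for an empty sequence
        printf("seq %d: no positions cached\n", seq_id);
        return;
    }

    // per the new comments, every position in [p_min, p_max] is present
    printf("seq %d: positions [%d, %d] cached\n", seq_id, p_min, p_max);
}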

src/llama-context.cpp

Lines changed: 34 additions & 4 deletions
@@ -6,9 +6,10 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
@@ -951,19 +952,48 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     res->set_inputs(&ubatch);
 
+    int ret = 0;
+
     const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
     if (compute_status != GGML_STATUS_SUCCESS) {
         switch (compute_status) {
             case GGML_STATUS_ABORTED:
-                return 2;
+                {
+                    ret = 2;
+                } break;
             case GGML_STATUS_ALLOC_FAILED:
-                return -2;
+                {
+                    ret = -2;
+                } break;
             case GGML_STATUS_FAILED:
             default:
-                return -3;
+                {
+                    ret = -3;
+                }
         }
     }
 
+    if (ret != 0) {
+        // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+        llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+
+        for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+            const auto & seq_id = ubatch.seq_id[i][0];
+
+            pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+        }
+
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                continue;
+            }
+
+            llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+        }
+
+        return ret;
+    }
+
     // plot the computation graph in dot format (for debugging purposes)
     //if (n_past%100 == 0) {
     //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
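
Taken together, the change turns a hard failure inside graph_compute into a recoverable condition: the positions of the failed ubatch are evicted, so the cache still satisfies the [pos_min, pos_max] guarantee added in llama.h. A hedged caller-side sketch of what this enables (context and batch setup are assumed; g_abort_requested and decode_with_abort are hypothetical):

#include "llama.h"

#include <atomic>
#include <cstdio>

// hypothetical flag, e.g. flipped from a signal handler or another thread
static std::atomic<bool> g_abort_requested{false};

// ggml_abort_callback: returning true aborts the in-flight graph compute
static bool abort_cb(void * /*data*/) {
    return g_abort_requested.load();
}

// Assumes `ctx` and `batch` were prepared by the caller.
static int decode_with_abort(struct llama_context * ctx, struct llama_batch batch) {
    llama_set_abort_callback(ctx, abort_cb, nullptr);

    const int ret = llama_decode(ctx, batch);
    switch (ret) {
        case 0:
            break; // success
        case 2:
            // aborted mid-decode; with this commit the positions of the failed
            // ubatch have been removed from the KV cache, leaving the caller
            // free to retry or roll back further
            fprintf(stderr, "decode aborted\n");
            break;
        default:
            // e.g. -2 (alloc failed) or -3 (compute failed), per the switch above
            fprintf(stderr, "decode failed: ret = %d\n", ret);
            break;
    }
    return ret;
}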
