2 files changed: +36 −4 lines changed

@@ -677,12 +677,14 @@ extern "C" {
 
     // Returns the smallest position present in the KV cache for the specified sequence
     // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_min(
             struct llama_context * ctx,
                    llama_seq_id   seq_id);
 
     // Returns the largest position present in the KV cache for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
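The two added comments pin down a useful invariant: per sequence, the KV cache holds a contiguous range of positions, so the (pos_min, pos_max) pair fully describes what is cached. A minimal caller-side sketch of what that enables (the helper name kv_seq_n_pos is illustrative, not part of the API):

    #include "llama.h"

    // number of cached positions for a sequence; relies on the documented
    // guarantee that every position in [pos_min, pos_max] is present
    static int32_t kv_seq_n_pos(struct llama_context * ctx, llama_seq_id seq_id) {
        const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id);
        const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id);

        if (p_min < 0) {
            return 0; // both functions return -1 when the sequence is empty
        }

        // contiguous by the guarantee above, so no need to scan for gaps
        return (int32_t) (p_max - p_min + 1);
    }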
@@ -6,9 +6,10 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
+#include <cinttypes>
 #include <cstring>
+#include <limits>
 #include <stdexcept>
-#include <cinttypes>
 
 //
 // llama_context
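(Context for the include changes: the new <limits> header backs the std::numeric_limits calls in the decode-failure handling below, and <cinttypes> appears to move up simply so the standard includes stay alphabetized.)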
@@ -951,19 +952,48 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     res->set_inputs(&ubatch);
 
+    int ret = 0;
+
     const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
     if (compute_status != GGML_STATUS_SUCCESS) {
         switch (compute_status) {
             case GGML_STATUS_ABORTED:
-                return 2;
+                {
+                    ret = 2;
+                } break;
             case GGML_STATUS_ALLOC_FAILED:
-                return -2;
+                {
+                    ret = -2;
+                } break;
             case GGML_STATUS_FAILED:
             default:
-                return -3;
+                {
+                    ret = -3;
+                }
         }
     }
 
+    if (ret != 0) {
+        // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+        llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+
+        for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+            const auto & seq_id = ubatch.seq_id[i][0];
+
+            pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+        }
+
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+                continue;
+            }
+
+            llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+        }
+
+        return ret;
+    }
+
     // plot the computation graph in dot format (for debugging purposes)
     //if (n_past%100 == 0) {
     //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
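One C++ subtlety in the rollback block: the initializer { std::numeric_limits<llama_pos>::max() } sets only pos_min[0] to max(); the remaining array elements are value-initialized to 0. As written, a sequence id greater than 0 with no tokens in the failed ubatch starts at 0 rather than max(), fails the continue check, and gets its cache cleared from position 0 by llama_kv_self_seq_rm. If that is unintended, an explicit fill would be needed. A standalone sketch of the initializer semantics (illustrative only, not part of the PR):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <limits>

    typedef int32_t llama_pos; // same typedef as in llama.h

    int main() {
        const int N = 4; // stand-in for LLAMA_MAX_PARALLEL_SEQUENCES

        // aggregate initialization: only a[0] is max(), a[1..N-1] become 0
        llama_pos a[N] = { std::numeric_limits<llama_pos>::max() };
        std::printf("a[0] = %d, a[1] = %d\n", (int) a[0], (int) a[1]);

        // an explicit fill sets every slot to max()
        llama_pos b[N];
        std::fill(b, b + N, std::numeric_limits<llama_pos>::max());
        std::printf("b[0] = %d, b[1] = %d\n", (int) b[0], (int) b[1]);

        return 0;
    }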