Commit de2ef53

kv-cache : rework kv_cell (#13706)
* kv-cache : rework kv_cell ggml-ci
* kv-cells : use "shift" instead of "delta" consistently ggml-ci
* llama : add llama_max_parallel_sequences() ggml-ci
* kv-cells : update comments [no ci]
* context : fail upon construction if sequences exceed max value ggml-ci
* kv-cells : get_pos() -> pos_get() + comments ggml-ci
* kv-cells : fix tracking of "used" cells ggml-ci
1 parent c508256 commit de2ef53

8 files changed: +470, -253 lines changed

include/llama.h

Lines changed: 1 addition & 0 deletions

@@ -471,6 +471,7 @@ extern "C" {
     LLAMA_API int64_t llama_time_us(void);
 
     LLAMA_API size_t llama_max_devices(void);
+    LLAMA_API size_t llama_max_parallel_sequences(void);
 
     LLAMA_API bool llama_supports_mmap       (void);
     LLAMA_API bool llama_supports_mlock      (void);
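
The new accessor lets clients discover the sequence cap at runtime instead of hard-coding it. As a rough usage sketch (my illustration, not part of the commit; it assumes the standard llama.cpp C API, in particular llama_context_default_params() and the n_seq_max field of llama_context_params, and the helper name is hypothetical):

#include <algorithm>
#include <cstdint>
#include <cstdio>

#include "llama.h"

// Hypothetical helper: build context params with a sequence count that
// respects the library's hard cap.
static llama_context_params make_ctx_params(uint32_t n_seq_requested) {
    llama_context_params params = llama_context_default_params();

    // Query the cap at runtime rather than assuming the value 64.
    const uint32_t n_seq_cap = (uint32_t) llama_max_parallel_sequences();

    if (n_seq_requested > n_seq_cap) {
        fprintf(stderr, "requested %u sequences, clamping to %u\n",
                n_seq_requested, n_seq_cap);
    }

    // Exceeding the cap would make context construction fail (see the
    // check added in llama-context.cpp below), so clamp here.
    params.n_seq_max = std::min(n_seq_requested, n_seq_cap);

    return params;
}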

src/llama-context.cpp

Lines changed: 5 additions & 1 deletion

@@ -25,7 +25,11 @@ llama_context::llama_context(
 
     const auto & hparams = model.hparams;
 
-    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    }
+
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;
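
This check makes an out-of-range n_seq_max a hard error at construction time rather than a source of bad per-sequence state later. From the C API the throw surfaces as context creation failing; as a hedged sketch (my example, assuming the usual llama_init_from_model() wrapper catches construction exceptions and returns NULL, and the helper name is hypothetical):

#include <cstdio>

#include "llama.h"

// Hypothetical failure path: asking for more sequences than the cap.
static llama_context * try_init(llama_model * model) {
    llama_context_params params = llama_context_default_params();
    params.n_seq_max = 128; // > LLAMA_MAX_PARALLEL_SEQUENCES (64)

    llama_context * ctx = llama_init_from_model(model, params);
    if (ctx == NULL) {
        // Construction threw; retry with a value within the cap.
        fprintf(stderr, "context creation failed; retrying with cap\n");
        params.n_seq_max = (uint32_t) llama_max_parallel_sequences();
        ctx = llama_init_from_model(model, params);
    }

    return ctx;
}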

src/llama-cparams.cpp

Lines changed: 4 additions & 0 deletions

@@ -1 +1,5 @@
 #include "llama-cparams.h"
+
+size_t llama_max_parallel_sequences(void) {
+    return LLAMA_MAX_PARALLEL_SEQUENCES;
+}

src/llama-cparams.h

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,8 @@
 
 #include <cstdint>
 
+#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+
 struct llama_cparams {
     uint32_t n_ctx;   // context size used during inference
     uint32_t n_batch;
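
Note the split: the numeric cap (LLAMA_MAX_PARALLEL_SEQUENCES) lives in the internal header src/llama-cparams.h, while the public header only gains the llama_max_parallel_sequences() accessor. This keeps the constant out of the public API surface, so the value can change between builds without breaking clients that query it at runtime rather than compiling it in.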
