
Commit 2252eef

kv-cache : improve find_slot() using min/max seq pos info

ggml-ci

1 parent 332f460 commit 2252eef

File tree

5 files changed: +63 -58 lines changed


src/llama-batch.cpp

Lines changed: 0 additions & 20 deletions

@@ -4,21 +4,6 @@
 #include <cstring>
 #include <algorithm>
 
-void llama_ubatch::update() {
-    if (equal_seqs) {
-        // TODO: for now don't compute min/max for recurrent batches since we don't need this.
-        //       the batches will be refactored anyway, so we'll fix this later
-        return;
-    }
-
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        const llama_seq_id s = seq_id[i][0];
-
-        seq_pos_min[s] = seq_pos_min[s] == -1 ? pos[i] : std::min(seq_pos_min[s], pos[i]);
-        seq_pos_max[s] = seq_pos_max[s] == -1 ? pos[i] : std::max(seq_pos_max[s], pos[i]);
-    }
-}
-
 llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
     // clear empty sequences
     // the previous ubatch is assumed to be gone,
@@ -47,8 +32,6 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
         /*n_tokens     =*/ 0,
         /*n_seq_tokens =*/ 0,
         /*n_seqs       =*/ 0,
-        /*seq_pos_min  =*/ {-1},
-        /*seq_pos_max  =*/ {-1},
         /*token        =*/ !has_embd ? udata.token.data() : nullptr,
         /*embd         =*/ has_embd ? udata.embd.data() : nullptr,
         /*pos          =*/ udata.pos.data(),
@@ -172,7 +155,6 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
        GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
        add_seq_to_ubatch(ubatch, s, length);
     }
-    ubatch.update();
     return ubatch;
 }
 
@@ -200,7 +182,6 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
        }
     }
-    ubatch.update();
     return ubatch;
 }
 
@@ -213,7 +194,6 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
        add_seq_to_ubatch(ubatch, s, length);
     }
-    ubatch.update();
     return ubatch;
 }
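For contrast with the incremental tracking that replaces it (see src/llama-kv-cells.h below), the deleted llama_ubatch::update() recomputed the per-sequence position extrema with one linear pass over the ubatch. A self-contained sketch of that fold, assuming plain vectors in place of the raw ubatch arrays, with -1 as the "not seen yet" sentinel:

#include <algorithm>
#include <cstdint>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// one pass over the tokens, widening [min, max] per sequence;
// -1 marks "no position seen yet", as in the deleted code
void fold_seq_extrema(
        const std::vector<llama_pos>    & pos,
        const std::vector<llama_seq_id> & seq, // first seq id of each token
        std::vector<llama_pos>          & seq_pos_min,
        std::vector<llama_pos>          & seq_pos_max) {
    for (size_t i = 0; i < pos.size(); ++i) {
        const llama_seq_id s = seq[i];

        seq_pos_min[s] = seq_pos_min[s] == -1 ? pos[i] : std::min(seq_pos_min[s], pos[i]);
        seq_pos_max[s] = seq_pos_max[s] == -1 ? pos[i] : std::max(seq_pos_max[s], pos[i]);
    }
}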

src/llama-batch.h

Lines changed: 0 additions & 6 deletions

@@ -1,26 +1,20 @@
 #pragma once
 
 #include "llama.h"
-#include "llama-cparams.h"
 
 #include <array>
 #include <vector>
 
 // very similar to llama_batch,
 // but has more metadata about sequences
 struct llama_ubatch {
-    void update();
-
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
     uint32_t n_seq_tokens; // tokens per sequence
     uint32_t n_seqs;
 
-    llama_pos seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; // min position of each sequence
-    llama_pos seq_pos_max[LLAMA_MAX_PARALLEL_SEQUENCES]; // max position of each sequence
-
     llama_token  *  token;    // [n_tokens]
     float        *  embd;     // [n_embd, n_tokens]
     llama_pos    *  pos;      // [n_tokens]

src/llama-context.cpp

Lines changed: 1 addition & 1 deletion

@@ -1233,7 +1233,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) {
     this->n_outputs = n_outputs;
 
     llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
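This one-line change follows mechanically from the struct edit in src/llama-batch.h above: llama_ubatch is aggregate-initialized here, so dropping the seq_pos_min/seq_pos_max members also drops their two {-1} initializers from the brace list. A reduced sketch of that correspondence (ubatch_sketch and its shortened field list are illustrative stand-ins, not the real struct):

#include <cstdint>

using llama_token = int32_t;

// reduced stand-in for llama_ubatch after the change: with the
// seq_pos_min/seq_pos_max arrays gone, aggregate initialization
// no longer carries the two {-1} entries between n_seqs and token
struct ubatch_sketch {
    bool          equal_seqs;
    uint32_t      n_tokens;
    uint32_t      n_seq_tokens;
    uint32_t      n_seqs;
    llama_token * token; // first of the pointer members
};

int main() {
    llama_token token = 0;
    // before: { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, ... }
    // after : { true, n_tokens, n_tokens / n_seqs, n_seqs,             &token, ... }
    ubatch_sketch ub = { true, 32, 8, 4, &token };
    (void) ub;
    return 0;
}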

src/llama-kv-cache.cpp

Lines changed: 37 additions & 15 deletions

@@ -548,7 +548,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         if (cells.is_empty(i)) {
             ss += '.';
         } else {
-            ss += 'x';
+            ss += std::to_string(cells.seq_get(i));
         }
         if (i%256 == 255) {
             ss += '\n';
@@ -557,6 +557,10 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         }
         LLAMA_LOG_WARN("\n%s\n", ss.c_str());
     }
+
+    LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[0] = %5d, max[0] = %5d\n", n_swa, cells.seq_pos_min(0), cells.seq_pos_max(0));
+    LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[1] = %5d, max[1] = %5d\n", n_swa, cells.seq_pos_min(1), cells.seq_pos_max(1));
+    LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[2] = %5d, max[2] = %5d\n", n_swa, cells.seq_pos_min(2), cells.seq_pos_max(2));
 #endif
 
     uint32_t n_tested = 0;
@@ -568,24 +572,44 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             continue;
         }
 
+        // keep track of what the minimum sequence positions would be if we accept the ubatch
+        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            seq_pos_min[s] = cells.seq_pos_min(s);
+        }
+
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
             const llama_pos    pos    = ubatch.pos[i];
             const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
             // can we use this cell? either:
             // - the cell is empty
-            // - the cell is occupied only by the same sequence, and the pos is masked
-            const bool can_use =
-                cells.is_empty(head_cur + i) ||
-                (
-                    cells.seq_has (head_cur + i, seq_id) && // sequence mask
-                    cells.seq_count(head_cur + i) == 1 &&
-                    (
-                        cells.pos_get (head_cur + i) >= pos || // causal mask
-                        is_masked_swa(cells.pos_get(head_cur + i), ubatch.seq_pos_min[seq_id]) // SWA mask
-                    )
-                );
+            // - the cell is occupied only by one sequence:
+            //   - mask causally, if the sequence is the same as the one we are inserting
+            //   - mask SWA, using current max pos for that sequence in the cache
+            //     always insert in the cell with minimum pos
+            bool can_use = cells.is_empty(head_cur + i);
+
+            if (!can_use && cells.seq_count(head_cur + i) == 1) {
+                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+
+                // causal mask
+                if (cells.seq_has(head_cur + i, seq_id)) {
+                    can_use = pos_cell >= pos;
+                }
+
+                if (!can_use) {
+                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+
+                    // SWA mask
+                    if (pos_cell == seq_pos_min[seq_id_cell] &&
+                        is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                        seq_pos_min[seq_id_cell]++;
+                        can_use = true;
+                    }
+                }
+            }
 
             if (!can_use) {
                 found = false;
@@ -613,9 +637,7 @@ void llama_kv_cache_unified::fill_slot(uint32_t head_cur, const llama_ubatch & ubatch) {
 
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
         if (!cells.is_empty(head + i)) {
-            cells.pos_chg(head + i, ubatch.pos[i]);
-
-            continue;
+            cells.rm(head + i);
        }
 
        cells.pos_set(head + i, ubatch.pos[i]);
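To make the new control flow easier to audit, here is the per-cell eligibility check from find_slot() restated as a free function. This is a sketch rather than llama.cpp code: the cell_* and masked_swa parameters stand in for the cells.* queries and for is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1), and seq_min_cell is the caller's running copy of seq_pos_min[seq_id_cell].

#include <cstdint>
#include <cstdio>

using llama_pos = int32_t;

// sketch of find_slot()'s per-cell decision; inputs mirror the diff above
bool can_use_cell(
        bool        cell_empty,     // cells.is_empty(head_cur + i)
        int         cell_seq_count, // cells.seq_count(head_cur + i)
        bool        cell_has_seq,   // cells.seq_has(head_cur + i, seq_id)
        llama_pos   pos_cell,       // cells.pos_get(head_cur + i)
        llama_pos   pos_new,        // ubatch.pos[i]
        bool        masked_swa,     // is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)
        llama_pos & seq_min_cell) { // running seq_pos_min[seq_id_cell]
    if (cell_empty) {
        return true; // empty cells are always usable
    }

    if (cell_seq_count != 1) {
        return false; // cells shared by several sequences are never overwritten
    }

    // causal mask: the same sequence may overwrite a position
    // at or after the one being inserted
    if (cell_has_seq && pos_cell >= pos_new) {
        return true;
    }

    // SWA mask: reclaim only the cell holding the current minimum pos of its
    // sequence, and only once that pos has slid out of the attention window
    if (pos_cell == seq_min_cell && masked_swa) {
        seq_min_cell++; // the next-oldest pos becomes the provisional minimum
        return true;
    }

    return false;
}

int main() {
    llama_pos seq_min = 10;
    // an occupied cell of another sequence at pos 10, currently that
    // sequence's minimum and already outside the window (masked_swa = true)
    const bool ok = can_use_cell(false, 1, false, 10, 4106, true, seq_min);
    printf("can_use = %d, provisional min = %d\n", ok, seq_min); // 1, 11
    return 0;
}

The seq_min_cell increment implements the "always insert in the cell with minimum pos" rule from the comment in the diff: a sequence's cells are reclaimed oldest-first, and advancing the provisional minimum by one lets the next cell of that run qualify on a later iteration.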

src/llama-kv-cells.h

Lines changed: 25 additions & 16 deletions

@@ -68,12 +68,6 @@ class llama_kv_cells_unified {
     // the index of the last cell that is used + 1
     // return 0 if no cells are used
     uint32_t used_max_p1() const {
-#if 0
-        if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
-        if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
-        if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
-#endif
-
         return used.empty() ? 0 : *used.rbegin() + 1;
     }
 
@@ -144,6 +138,18 @@ class llama_kv_cells_unified {
         }
     }
 
+    void rm(uint32_t i) {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        seq_pos_rm(i);
+
+        pos[i] = -1;
+        seq[i].reset();
+
+        used.erase(i);
+    }
+
     // note: call only if the cell has seq_id
     // return true if the cell becomes empty
     bool seq_rm(uint32_t i, llama_seq_id seq_id) {
@@ -220,6 +226,18 @@ class llama_kv_cells_unified {
         seq_pos[seq_id].insert(pos[i]);
     }
 
+    llama_seq_id seq_get(uint32_t i) const {
+        assert(seq[i].count() == 1);
+
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (seq[i].test(s)) {
+                return s;
+            }
+        }
+
+        return -1;
+    }
+
     // the minimum position of sequence seq_id currently present in any of the cells
     // return -1 if the sequence is not present
     llama_pos seq_pos_min(llama_seq_id seq_id) const {
@@ -275,22 +293,13 @@ class llama_kv_cells_unified {
     void pos_set(uint32_t i, llama_pos p) {
         assert(i < pos.size());
         assert(pos[i] == -1);
+        assert(seq[i].none());
 
         pos[i] = p;
 
         used.insert(i);
     }
 
-    // change the position of a non-empty cell
-    // does not modify "has_shift"
-    // note: call only if the cell is not empty
-    void pos_chg(uint32_t i, llama_pos p) {
-        assert(i < pos.size());
-        assert(pos[i] != -1);
-
-        pos[i] = p;
-    }
-
     // pos[i] = pos[i] + d
     // sets "has_shift" to true
     // note: call only if the cell is not empty
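The deleted #if 0 block above hints at the data structure behind seq_pos_min()/seq_pos_max(): an ordered set of positions per sequence (seq_pos[s]), kept in sync as cells gain and lose positions, so the extrema are O(1) reads of begin()/rbegin(). A minimal sketch of that bookkeeping, assuming a std::set per sequence (the names below are illustrative, not the class's interface):

#include <cstdint>
#include <cstdio>
#include <set>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// stand-in for LLAMA_MAX_PARALLEL_SEQUENCES
constexpr int MAX_SEQ = 64;

// minimal sketch: one ordered set of positions per sequence, so
// min/max queries are O(1) reads of begin()/rbegin()
struct seq_pos_tracker {
    std::set<llama_pos> seq_pos[MAX_SEQ];

    void add(llama_seq_id s, llama_pos p) { seq_pos[s].insert(p); }
    void rm (llama_seq_id s, llama_pos p) { seq_pos[s].erase (p); }

    // -1 when the sequence occupies no cells, matching the class's contract
    llama_pos min(llama_seq_id s) const {
        return seq_pos[s].empty() ? -1 : *seq_pos[s].begin();
    }
    llama_pos max(llama_seq_id s) const {
        return seq_pos[s].empty() ? -1 : *seq_pos[s].rbegin();
    }
};

int main() {
    seq_pos_tracker t;
    t.add(0, 5);
    t.add(0, 7);
    t.rm (0, 5); // e.g. a cell reclaimed by rm(i) via seq_pos_rm(i)
    printf("min = %d, max = %d\n", t.min(0), t.max(0)); // min = 7, max = 7
    return 0;
}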
