Skip to content

Commit c1434b8

Browse files
committed
kv-cache : improve slot allocation logic
ggml-ci
1 parent 50d9a51 commit c1434b8

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

src/llama-kv-cache.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -554,16 +554,18 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
554554

555555
bool found = true;
556556
for (uint32_t i = 0; i < n_tokens; i++) {
557-
// TODO: improve to accept cells that are masked by the SWA
558-
//if (!cells.is_empty(head_cur + i)) {
559-
560557
const llama_seq_id seq_id = ubatch.seq_id[i][0];
561558

559+
// can we use this cell? either:
560+
// - the cell is empty
561+
// - the cell is occupied only by the same sequence, its position does not exceed the incoming token's position (causal), and it is masked by the SWA
562562
const bool can_use =
563563
cells.is_empty(head_cur + i) ||
564564
(
565-
cells.seq_has(head_cur + i, seq_id) && // TODO: seq_has_only
566-
is_masked_swa(cells.pos_get(head_cur + i), ubatch.seq_pos_min[seq_id])
565+
cells.pos_get(head_cur + i) <= ubatch.pos[i] && // causal mask
566+
cells.seq_has(head_cur + i, seq_id) && // sequence mask
567+
cells.seq_count(head_cur + i) == 1 &&
568+
is_masked_swa(cells.pos_get(head_cur + i), ubatch.seq_pos_min[seq_id]) // SWA mask
567569
);
568570

569571
if (!can_use) {

src/llama-kv-cells.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,13 @@ class llama_kv_cells_unified {
155155
return false;
156156
}
157157

158+
// Returns seq[i].count() for cell i; find_slot compares this against 1 to require
// that the cell is occupied by a single sequence only.
// Preconditions (checked with assert): i < pos.size() and pos[i] != -1.
// NOTE(review): pos[i] == -1 presumably marks an empty cell, so callers should test
// is_empty() first - confirm against the rest of the class.
// NOTE(review): count() likely returns an unsigned size type (e.g. std::bitset::count);
// confirm the implicit conversion to int is intended.
int seq_count(uint32_t i) const {
159+
assert(i < pos.size());
160+
assert(pos[i] != -1);
161+
162+
return seq[i].count();
163+
}
164+
158165
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
159166
assert(i < pos.size());
160167
assert(seq_id >= 0);

0 commit comments

Comments
 (0)