Skip to content

Commit f23e4cc

Browse files
committed
kv-cache : add comments
ggml-ci
1 parent 2b984f4 commit f23e4cc

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

src/llama-kv-cache.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -493,9 +493,13 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
493493
LLAMA_LOG_WARN("\n%s\n", ss.c_str());
494494
}
495495

496-
LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[0] = %5d, max[0] = %5d\n", n_swa, cells.seq_pos_min(0), cells.seq_pos_max(0));
497-
LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[1] = %5d, max[1] = %5d\n", n_swa, cells.seq_pos_min(1), cells.seq_pos_max(1));
498-
LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[2] = %5d, max[2] = %5d\n", n_swa, cells.seq_pos_min(2), cells.seq_pos_max(2));
496+
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
497+
if (cells.seq_pos_min(s) < 0) {
498+
continue;
499+
}
500+
501+
LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
502+
}
499503
#endif
500504

501505
uint32_t n_tested = 0;
@@ -538,6 +542,9 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
538542
const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
539543

540544
// SWA mask
545+
// note: we insert only in the cell with minimum pos in order to preserve the invariant that
546+
// all positions between [pos_min, pos_max] for each sequence will be present in the cache
547+
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
541548
if (pos_cell == seq_pos_min[seq_id_cell] &&
542549
is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
543550
seq_pos_min[seq_id_cell]++;

src/llama-kv-cells.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class llama_kv_cells_unified {
138138
}
139139
}
140140

141+
// clear a non-empty cell
141142
void rm(uint32_t i) {
142143
assert(i < pos.size());
143144
assert(pos[i] != -1);
@@ -202,13 +203,15 @@ class llama_kv_cells_unified {
202203
return false;
203204
}
204205

206+
// number of different sequences in the cell
205207
int seq_count(uint32_t i) const {
206208
assert(i < pos.size());
207209
assert(pos[i] != -1);
208210

209211
return seq[i].count();
210212
}
211213

214+
// check if the cell contains seq_id
212215
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
213216
assert(i < pos.size());
214217
assert(seq_id >= 0);
@@ -226,6 +229,8 @@ class llama_kv_cells_unified {
226229
seq_pos[seq_id].insert(pos[i]);
227230
}
228231

232+
// return the sequence id of this cell
233+
// note: call only for cells with exactly one sequence
229234
llama_seq_id seq_get(uint32_t i) const {
230235
assert(seq[i].count() == 1);
231236

0 commit comments

Comments
 (0)