Skip to content

Commit 1ec785c

Browse files
committed
kv-cells : get_pos() -> pos_get() + comments
ggml-ci
1 parent b495877 commit 1ec785c

File tree

2 files changed: +19 additions, −16 deletions

src/llama-kv-cache.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
290290

291291
for (uint32_t i = 0; i < cells.size(); ++i) {
292292
if (cells.seq_has(i, seq_id)) {
293-
result = std::min(result, cells.get_pos(i));
293+
result = std::min(result, cells.pos_get(i));
294294
}
295295
}
296296

@@ -306,7 +306,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
306306

307307
for (uint32_t i = 0; i < cells.size(); ++i) {
308308
if (cells.seq_has(i, seq_id)) {
309-
result = std::max(result, cells.get_pos(i));
309+
result = std::max(result, cells.pos_get(i));
310310
}
311311
}
312312

@@ -611,7 +611,7 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam
611611
continue;
612612
}
613613

614-
const llama_pos p0 = cells.get_pos(i);
614+
const llama_pos p0 = cells.pos_get(i);
615615

616616
if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
617617
n_attended++;
@@ -664,7 +664,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
664664
if (cells.is_empty(i)) {
665665
masked = true;
666666
} else {
667-
const llama_pos p0 = cells.get_pos(i);
667+
const llama_pos p0 = cells.pos_get(i);
668668

669669
// mask the token if not the same sequence
670670
masked = masked || (!cells.seq_has(i, seq_id));
@@ -724,7 +724,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
724724
for (int j = 0; j < n_tokens; ++j) {
725725
for (int i = 0; i < n_kv; ++i) {
726726
// the position when the cells is empty is irrelevant - it will be masked out later in the attention
727-
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.get_pos(i);
727+
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
728728

729729
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
730730
}
@@ -1250,7 +1250,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
12501250
}
12511251
}
12521252

1253-
const llama_pos pos = cells.get_pos(i);
1253+
const llama_pos pos = cells.pos_get(i);
12541254
const uint32_t n_seq_id = seq_ids.size();
12551255

12561256
io.write(&pos, sizeof(pos));
@@ -1394,8 +1394,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
13941394
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
13951395
// Assume that this is one contiguous block of cells
13961396
GGML_ASSERT(head + cell_count <= cells.size());
1397-
GGML_ASSERT(cells.get_pos(head) == batch.pos[0]);
1398-
GGML_ASSERT(cells.get_pos(head + cell_count - 1) == batch.pos[cell_count - 1]);
1397+
GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
1398+
GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
13991399
GGML_ASSERT(cells.seq_has(head, dest_seq_id));
14001400
GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
14011401
} else {

src/llama-kv-cells.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class llama_kv_cells_unified {
5757
return has_shift;
5858
}
5959

60-
// move cell isrc to idst
60+
// move cell isrc to idst (used during defrag)
6161
void mv(uint32_t isrc, uint32_t idst) {
6262
assert(isrc < pos.size());
6363
assert(idst < pos.size());
@@ -71,7 +71,7 @@ class llama_kv_cells_unified {
7171
seq [isrc].reset();
7272
}
7373

74-
// copy the state of cells [i, i + n)
74+
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
7575
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
7676
assert(i + n <= pos.size());
7777

@@ -89,7 +89,7 @@ class llama_kv_cells_unified {
8989
return res;
9090
}
9191

92-
// set the state of cells [i, i + other.pos.size())
92+
// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
9393
void set(uint32_t i, const llama_kv_cells_unified & other) {
9494
assert(i + other.pos.size() <= pos.size());
9595

@@ -179,7 +179,7 @@ class llama_kv_cells_unified {
179179
}
180180

181181
// note: call only if the cell is not empty
182-
llama_pos get_pos(uint32_t i) const {
182+
llama_pos pos_get(uint32_t i) const {
183183
assert(i < pos.size());
184184
assert(pos[i] != -1);
185185

@@ -264,11 +264,14 @@ class llama_kv_cells_unified {
264264
// cells.pos_add(x, shift_x);
265265
// cells.pos_div(y, shift_y);
266266
// ...
267-
// for (int i = 0; i < n; ++i) {
268-
// auto shift_i = cells.get_shift(i);
269-
// ...
267+
//
268+
// if (cells.has_shift()) {
269+
// for (int i = 0; i < n; ++i) {
270+
// auto shift_i = cells.get_shift(i);
271+
// ...
272+
// }
273+
// cells.reset_shift();
270274
// }
271-
// cells.reset_shift();
272275
//
273276
std::vector<llama_pos> shift;
274277

0 commit comments

Comments (0)