Skip to content

Commit 44856a7

Browse files
committed
kv-cells : use "shift" instead of "delta" consistently
ggml-ci
1 parent 9023ae3 commit 44856a7

File tree

4 files changed

+60
-47
lines changed

4 files changed

+60
-47
lines changed

src/llama-kv-cache.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
217217
}
218218
}
219219

220-
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
221-
if (delta == 0) {
220+
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
221+
if (shift == 0) {
222222
return;
223223
}
224224

@@ -243,7 +243,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
243243
}
244244

245245
if (cells.seq_has(i, seq_id)) {
246-
if (cells.pos_add(i, delta)) {
246+
if (cells.pos_add(i, shift)) {
247247
if (new_head == cells.size()) {
248248
new_head = i;
249249
}
@@ -336,7 +336,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
336336

337337
auto * sched = lctx.get_sched();
338338

339-
if (cells.pos_has_shift()) {
339+
if (cells.get_has_shift()) {
340340
if (!get_can_shift()) {
341341
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
342342
}
@@ -360,7 +360,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
360360
need_reserve = true;
361361
}
362362

363-
cells.pos_reset_delta();
363+
cells.reset_shift();
364364
}
365365

366366
if (do_defrag) {
@@ -706,7 +706,7 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
706706
int32_t * data = (int32_t *) dst->data;
707707

708708
for (uint32_t i = 0; i < cells.size(); ++i) {
709-
data[i] = cells.is_empty(i) ? 0 : cells.get_delta(i);
709+
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
710710
}
711711
}
712712

@@ -1631,9 +1631,9 @@ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
16311631
kv_swa ->seq_keep(seq_id);
16321632
}
16331633

1634-
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
1635-
kv_base->seq_add(seq_id, p0, p1, delta);
1636-
kv_swa ->seq_add(seq_id, p0, p1, delta);
1634+
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
1635+
kv_base->seq_add(seq_id, p0, p1, shift);
1636+
kv_swa ->seq_add(seq_id, p0, p1, shift);
16371637
}
16381638

16391639
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
@@ -2005,8 +2005,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
20052005
}
20062006
}
20072007

2008-
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
2009-
if (delta == 0) {
2008+
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
2009+
if (shift == 0) {
20102010
return;
20112011
}
20122012

@@ -2029,7 +2029,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
20292029
if (tail_id >= 0) {
20302030
kv_cell & cell = cells[tail_id];
20312031
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
2032-
cell.pos += delta;
2032+
cell.pos += shift;
20332033
}
20342034
}
20352035
}

src/llama-kv-cache.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
123123
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
124124
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
125125
void seq_keep(llama_seq_id seq_id) override;
126-
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
126+
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
127127
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
128128

129129
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
@@ -316,7 +316,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
316316
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
317317
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
318318
void seq_keep(llama_seq_id seq_id) override;
319-
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
319+
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
320320
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
321321

322322
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
@@ -422,7 +422,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
422422
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
423423
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
424424
void seq_keep(llama_seq_id seq_id) override;
425-
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
425+
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
426426
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
427427

428428
llama_pos seq_pos_min(llama_seq_id seq_id) const override;

src/llama-kv-cells.h

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,32 @@
11
#pragma once
22

3+
#include "llama.h"
4+
35
#include <bitset>
46
#include <cassert>
57
#include <vector>
68

7-
using llama_pos = int32_t;
8-
using llama_seq_id = int32_t;
9-
109
// meta information about KV cells that can be part of multiple sequences at the same time
1110
// TODO: add unit tests
12-
struct llama_kv_cells_unified {
11+
class llama_kv_cells_unified {
12+
public:
1313
void reset() {
1414
for (uint32_t i = 0; i < pos.size(); ++i) {
1515
pos[i] = -1;
16-
delta[i] = 0;
16+
shift[i] = 0;
1717
seq[i].reset();
1818
}
1919

2020
used = 0;
21-
has_delta = false;
21+
has_shift = false;
22+
}
23+
24+
void reset_shift() {
25+
has_shift = false;
26+
27+
for (uint32_t i = 0; i < shift.size(); ++i) {
28+
shift[i] = 0;
29+
}
2230
}
2331

2432
uint32_t size() const {
@@ -27,7 +35,7 @@ struct llama_kv_cells_unified {
2735

2836
void resize(uint32_t n) {
2937
pos.resize(n);
30-
delta.resize(n);
38+
shift.resize(n);
3139
seq.resize(n);
3240

3341
reset();
@@ -44,17 +52,21 @@ struct llama_kv_cells_unified {
4452
return used;
4553
}
4654

55+
bool get_has_shift() const {
56+
return has_shift;
57+
}
58+
4759
// move cell isrc to idst
4860
void mv(uint32_t isrc, uint32_t idst) {
4961
assert(isrc < pos.size());
5062
assert(idst < pos.size());
5163

5264
pos [idst] = pos [isrc];
53-
delta[idst] = delta[isrc];
65+
shift[idst] = shift[isrc];
5466
seq [idst] = seq [isrc];
5567

5668
pos [isrc] = -1;
57-
delta[isrc] = 0;
69+
shift[isrc] = 0;
5870
seq [isrc].reset();
5971
}
6072

@@ -70,7 +82,7 @@ struct llama_kv_cells_unified {
7082
res.pos[j] = pos[i + j];
7183
res.seq[j] = seq[i + j];
7284

73-
assert(delta[i + j] == 0);
85+
assert(shift[i + j] == 0);
7486
}
7587

7688
return res;
@@ -92,7 +104,7 @@ struct llama_kv_cells_unified {
92104
pos[i + j] = other.pos[j];
93105
seq[i + j] = other.seq[j];
94106

95-
assert(delta[i + j] == 0);
107+
assert(shift[i + j] == 0);
96108
}
97109
}
98110

@@ -174,11 +186,11 @@ struct llama_kv_cells_unified {
174186
}
175187

176188
// note: call only if the cell is not empty
177-
llama_pos get_delta(uint32_t i) const {
189+
llama_pos get_shift(uint32_t i) const {
178190
assert(i < pos.size());
179191
assert(pos[i] != -1);
180192

181-
return delta[i];
193+
return shift[i];
182194
}
183195

184196
bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
@@ -203,9 +215,9 @@ struct llama_kv_cells_unified {
203215
assert(pos[i] != -1);
204216

205217
pos[i] += d;
206-
delta[i] += d;
218+
shift[i] += d;
207219

208-
has_delta = true;
220+
has_shift = true;
209221

210222
if (pos[i] < 0) {
211223
pos[i] = -1;
@@ -228,30 +240,31 @@ struct llama_kv_cells_unified {
228240
const llama_pos p_old = pos[i];
229241

230242
pos[i] /= d;
231-
delta[i] += p_old - pos[i];
243+
shift[i] += p_old - pos[i];
232244

233-
has_delta = true;
234-
}
235-
236-
bool pos_has_shift() const {
237-
return has_delta;
238-
}
239-
240-
void pos_reset_delta() {
241-
has_delta = false;
242-
243-
for (uint32_t i = 0; i < delta.size(); ++i) {
244-
delta[i] = 0;
245-
}
245+
has_shift = true;
246246
}
247247

248248
private:
249249
uint32_t used = 0; // used cells (i.e. at least one seq_id)
250250

251-
bool has_delta = false;
251+
bool has_shift = false;
252252

253253
std::vector<llama_pos> pos;
254-
std::vector<llama_pos> delta;
254+
255+
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
256+
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
257+
//
258+
// cells.pos_add(x, shift_x);
259+
// cells.pos_div(y, shift_y);
260+
// ...
261+
// for (int i = 0; i < n; ++i) {
262+
// auto shift_i = cells.get_shift(i);
263+
// ...
264+
// }
265+
// cells.reset_shift();
266+
//
267+
std::vector<llama_pos> shift;
255268

256269
// TODO: assert n_seq_max <= 64
257270
std::vector<std::bitset<64>> seq;

src/llama-memory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class llama_memory_i {
2222
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
2323
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
2424
virtual void seq_keep(llama_seq_id seq_id) = 0;
25-
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
25+
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
2626
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
2727

2828
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;

0 commit comments

Comments
 (0)