Skip to content

Commit 1eec34a

Browse files
committed
kv-cache : simplify the "struct llama_kv_cache" interface
ggml-ci
1 parent de2ef53 commit 1eec34a

11 files changed

+646
-438
lines changed

include/llama.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,9 @@ extern "C" {
259259
llama_token * token;
260260
float * embd;
261261
llama_pos * pos;
262-
int32_t * n_seq_id;
263-
llama_seq_id ** seq_id;
264-
int8_t * logits; // TODO: rename this to "output"
262+
int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
263+
llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
264+
int8_t * logits; // TODO: rename this to "output"
265265
} llama_batch;
266266

267267
enum llama_model_kv_override_type {
@@ -698,6 +698,7 @@ extern "C" {
698698
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
699699

700700
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
701+
// TODO: deprecate and always update the cache lazily
701702
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
702703

703704
//

src/llama-batch.cpp

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,21 @@
44
#include <cstring>
55
#include <algorithm>
66

7+
void llama_ubatch::update() {
8+
if (equal_seqs) {
9+
// TODO: for now don't compute min/max for recurrent batches since we don't need this.
10+
// the batches will be refactored anyway, so we'll fix this later
11+
return;
12+
}
13+
14+
for (uint32_t i = 0; i < n_tokens; ++i) {
15+
const llama_seq_id s = seq_id[i][0];
16+
17+
seq_pos_min[s] = seq_pos_min[s] == -1 ? pos[i] : std::min(seq_pos_min[s], pos[i]);
18+
seq_pos_max[s] = seq_pos_max[s] == -1 ? pos[i] : std::max(seq_pos_max[s], pos[i]);
19+
}
20+
}
21+
722
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
823
// clear empty sequences
924
// the previous ubatch is assumed to be gone,
@@ -15,24 +30,33 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
1530
break;
1631
}
1732
}
18-
ubatch_token.resize(!has_embd ? n_ubatch : 0);
19-
ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
20-
ubatch_pos.resize(n_ubatch);
21-
ubatch_n_seq_id.resize(n_ubatch);
22-
ubatch_seq_id.resize(n_ubatch);
23-
ubatch_output.resize(n_ubatch);
33+
34+
udatas.push_back({});
35+
36+
auto & udata = udatas.back();
37+
38+
udata.token.resize(!has_embd ? n_ubatch : 0);
39+
udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
40+
udata.pos.resize(n_ubatch);
41+
udata.n_seq_id.resize(n_ubatch);
42+
udata.seq_id.resize(n_ubatch);
43+
udata.output.resize(n_ubatch);
44+
2445
llama_ubatch ubatch = {
2546
/*equal_seqs =*/ true,
2647
/*n_tokens =*/ 0,
2748
/*n_seq_tokens =*/ 0,
2849
/*n_seqs =*/ 0,
29-
/*token =*/ !has_embd ? ubatch_token.data() : nullptr,
30-
/*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
31-
/*pos =*/ ubatch_pos.data(),
32-
/*n_seq_id =*/ ubatch_n_seq_id.data(),
33-
/*seq_id =*/ ubatch_seq_id.data(),
34-
/*output =*/ ubatch_output.data(),
50+
/*seq_pos_min =*/ {-1},
51+
/*seq_pos_max =*/ {-1},
52+
/*token =*/ !has_embd ? udata.token.data() : nullptr,
53+
/*embd =*/ has_embd ? udata.embd.data() : nullptr,
54+
/*pos =*/ udata.pos.data(),
55+
/*n_seq_id =*/ udata.n_seq_id.data(),
56+
/*seq_id =*/ udata.seq_id.data(),
57+
/*output =*/ udata.output.data(),
3558
};
59+
3660
return ubatch;
3761
}
3862

@@ -148,6 +172,7 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
148172
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
149173
add_seq_to_ubatch(ubatch, s, length);
150174
}
175+
ubatch.update();
151176
return ubatch;
152177
}
153178

@@ -175,6 +200,7 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
175200
if (length + n_tokens_in_ubatch > n_ubatch) { break; }
176201
}
177202
}
203+
ubatch.update();
178204
return ubatch;
179205
}
180206

@@ -187,6 +213,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
187213
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
188214
add_seq_to_ubatch(ubatch, s, length);
189215
}
216+
ubatch.update();
190217
return ubatch;
191218
}
192219

src/llama-batch.h

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,31 @@
11
#pragma once
22

33
#include "llama.h"
4+
#include "llama-cparams.h"
45

56
#include <array>
67
#include <vector>
78

89
// very similar to llama_batch,
910
// but has more metadata about sequences
1011
struct llama_ubatch {
12+
void update();
13+
1114
bool equal_seqs;
1215
// TODO: whole_seqs for embeddings?
1316

14-
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
17+
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
1518
uint32_t n_seq_tokens; // tokens per sequence
1619
uint32_t n_seqs;
1720

21+
llama_pos seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; // min position of each sequence
22+
llama_pos seq_pos_max[LLAMA_MAX_PARALLEL_SEQUENCES]; // max position of each sequence
23+
1824
llama_token * token; // [n_tokens]
1925
float * embd; // [n_embd, n_tokens]
2026
llama_pos * pos; // [n_tokens]
21-
int32_t * n_seq_id; // [n_seqs]
22-
llama_seq_id ** seq_id; // [n_seqs]
27+
int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
28+
llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
2329
int8_t * output; // [n_tokens]
2430
};
2531

@@ -49,13 +55,18 @@ struct llama_sbatch {
4955

5056
const llama_batch * batch = nullptr;
5157

52-
// buffers for the ubatch
53-
std::vector<llama_token> ubatch_token;
54-
std::vector<float> ubatch_embd;
55-
std::vector<llama_pos> ubatch_pos;
56-
std::vector<int32_t> ubatch_n_seq_id;
57-
std::vector<llama_seq_id *> ubatch_seq_id;
58-
std::vector<int8_t> ubatch_output;
58+
// buffers for the ubatches
59+
// TODO: very hacky, this needs a complete rework
60+
struct ubatch_data {
61+
std::vector<llama_token> token;
62+
std::vector<float> embd;
63+
std::vector<llama_pos> pos;
64+
std::vector<int32_t> n_seq_id;
65+
std::vector<llama_seq_id *> seq_id;
66+
std::vector<int8_t> output;
67+
};
68+
69+
std::vector<ubatch_data> udatas;
5970

6071
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
6172

0 commit comments

Comments
 (0)