 #include "../src/llama-model.h"
 
 #include "common.h"
+#include "ggml.h"
 #include "llama.h"
 
 #include <algorithm>
@@ -59,6 +60,43 @@ static std::shared_ptr<llama_model> _make_model(
     return model;
 }
 
+// Helper to build a llama_batch: one entry in token_seqs per sequence, with a
+// parallel entry in seq_ids listing the seq_id(s) each sequence belongs to
+static llama_batch _make_batch(
+        std::vector<std::vector<llama_token>> token_seqs,
+        std::vector<std::vector<llama_seq_id>> seq_ids) {
+    GGML_ASSERT(token_seqs.size() == seq_ids.size());
+
+    // Size the batch for the total token count across all sequences
+    size_t total_tokens = 0;
+    for (const auto & token_seq : token_seqs) {
+        total_tokens += token_seq.size();
+    }
+    size_t max_seq_ids = 0;
+    for (const auto & seq_ids_i : seq_ids) {
+        max_seq_ids = std::max(max_seq_ids, seq_ids_i.size());
+    }
+    llama_batch batch = llama_batch_init(total_tokens, 0, max_seq_ids);
+
+    // Add each token at its position within its sequence, without logits
+    for (size_t i = 0; i < token_seqs.size(); ++i) {
+        const auto & token_seq = token_seqs[i];
+        const auto & seq_ids_i = seq_ids[i];
+        for (int pos = 0; pos < (int) token_seq.size(); ++pos) {
+            common_batch_add(batch, token_seq[pos], pos, seq_ids_i, false);
+        }
+    }
+    return batch;
+}
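+
+// Usage sketch: _make_batch({{101, 1, 102}}, {{42}}) builds a batch holding a
+// single 3-token sequence whose tokens all belong to seq_id 42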
+
+// Recursively check whether `parent` appears anywhere in the src graph of
+// `child`; ggml graphs are DAGs, so the recursion terminates
+static bool is_source_tensor(ggml_tensor * child, ggml_tensor * parent) {
+    if (!child || !parent) return false;
+    for (int i = 0; i < GGML_MAX_SRC; ++i) {
+        if (child->src[i] == parent || is_source_tensor(child->src[i], parent)) {
+            return true;
+        }
+    }
+    return false;
+}
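+
+// e.g. for t2 = ggml_scale(ctx, t1, 2.0f), is_source_tensor(t2, t1) is true,
+// because t1 is src[0] of t2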
+
 struct log_scope {
     const char * name;
     explicit log_scope(const char * name) : name(name) {
@@ -121,33 +159,47 @@ static void test_llama_kv_cache_unified_single_seq() {
     );
 
     // Create the micro batch with a single 3-token sequence
-    //
-    // NOTE: A bunch of these asserts were just me figuring out how the batches
-    // relate to each other, but they're left for future readers to help in the
-    // same understanding process.
-    llama_seq_id seq_id = 42;
-    llama_batch batch = llama_batch_init(3, 0, 1);
-    common_batch_add(batch, 101, 0, {seq_id}, false);
-    common_batch_add(batch, 1, 1, {seq_id}, false);
-    common_batch_add(batch, 102, 2, {seq_id}, false);
-    llama_sbatch sbatch(batch, 0, true, false);
-    GGML_ASSERT(batch.n_tokens == 3);
-    GGML_ASSERT(sbatch.n_tokens == 3);
-    GGML_ASSERT(!sbatch.seq.empty());
-    llama_ubatch ubatch = sbatch.split_simple(4);
-    printf("ubatch.n_seqs=%d\n", ubatch.n_seqs);
-    GGML_ASSERT(ubatch.n_seqs == 3);
-    GGML_ASSERT(ubatch.n_seq_tokens == 1);
-    GGML_ASSERT(ubatch.n_tokens == 3);
-    GGML_ASSERT(ubatch.seq_id[0][0] == seq_id);
-    GGML_ASSERT(ubatch.seq_id[1][0] == seq_id);
-    GGML_ASSERT(ubatch.seq_id[2][0] == seq_id);
+    llama_batch batch1 = _make_batch({{101, 1, 102}}, {{42}});
+    llama_sbatch sbatch1 = cache.sbatch_init(batch1, false);
+    llama_ubatch ubatch1 = cache.ubatch_next(sbatch1, 4, false);
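+    // (with an n_ubatch size of 4, all 3 tokens land in this single ubatch)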
 
     // Find a slot for a new sequence
-    GGML_ASSERT(cache.find_slot(ubatch));
+    GGML_ASSERT(cache.find_slot(ubatch1));
+
+    // Cache the k/v for a single layer in this slot
+    ggml_context * ctx = ggml_init({10240, NULL, false});
+    ggml_tensor * k1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0));
+    ggml_tensor * v1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0));
+    ggml_tensor * k1_view = cache.cpy_k(ctx, k1, 0);
+    ggml_tensor * v1_view = cache.cpy_v(ctx, v1, 0);
+    GGML_ASSERT(is_source_tensor(k1_view, k1));
+    GGML_ASSERT(is_source_tensor(v1_view, v1));
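+    // (the tensors returned by cpy_k/cpy_v are the graph ops that write the
+    // new k/v into the cache, so k1/v1 must appear in their src chains)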
+
+    // Create a second batch with different tokens and find a slot for it
+    llama_batch batch2 = _make_batch({{1, 2, 3, 4}}, {{5}});
+    llama_sbatch sbatch2 = cache.sbatch_init(batch2, false);
+    llama_ubatch ubatch2 = cache.ubatch_next(sbatch2, 4, false);
+    GGML_ASSERT(cache.find_slot(ubatch2));
+
+    // Add some different tensors
+    ggml_tensor * k2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0));
+    ggml_tensor * v2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0));
+    ggml_tensor * k2_view = cache.cpy_k(ctx, k2, 0);
+    ggml_tensor * v2_view = cache.cpy_v(ctx, v2, 0);
+    GGML_ASSERT(is_source_tensor(k2_view, k2));
+    GGML_ASSERT(is_source_tensor(v2_view, v2));
+
+    // Make sure the first batch's k/v tensors are not sources of the second
+    // batch's views
+    GGML_ASSERT(!is_source_tensor(k2_view, k1));
+    GGML_ASSERT(!is_source_tensor(v2_view, v1));
+
+    // Re-find the slot for the first batch and make sure it still succeeds
+    GGML_ASSERT(cache.find_slot(ubatch1));
 
     // Clean up
-    llama_batch_free(batch);
+    llama_batch_free(batch1);
+    llama_batch_free(batch2);
+    ggml_free(ctx);
 }
 
 /*- Recurrent Cache ----------------------------------------------------------*/