Commit ba118a2

tests(wip): More robust test for unified cache
I'm still not clear how cache hits should be detected, since find_slot does not seem to take into account the tokens themselves and simply looks for a sequence of cells that fits the size of the ubatch and has no set positions in any of the cells. I'm clearly still missing something about how this works!

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 5c11928 commit ba118a2
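
For reference, below is a minimal sketch of the slot-search behavior described in the message above: scan the cache cells for the first contiguous run of empty cells (no position set) that is large enough for the ubatch, ignoring the token values entirely. This is an illustrative paraphrase, not the actual llama_kv_cache_unified::find_slot implementation; the kv_cell_sketch and find_free_run names are invented for the example.

// Illustrative sketch only: models the cache as a vector of cells, where a
// cell with pos < 0 is empty. The search looks only at cell occupancy and the
// ubatch size, never at the tokens themselves.
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_cell_sketch {
    int32_t pos = -1; // -1 means no position has been set for this cell
};

// Returns the start index of the first run of n_ubatch empty cells, or -1 if
// no contiguous run is large enough.
static int64_t find_free_run(const std::vector<kv_cell_sketch> & cells, uint32_t n_ubatch) {
    uint32_t run = 0;
    for (size_t i = 0; i < cells.size(); ++i) {
        run = (cells[i].pos < 0) ? run + 1 : 0;
        if (run == n_ubatch) {
            return (int64_t) (i + 1 - n_ubatch); // head of the slot
        }
    }
    return -1;
}

int main() {
    std::vector<kv_cell_sketch> cells(8);
    cells[2].pos = 0; // pretend one cell is already occupied
    // Finds the run starting at index 3 (cells 3..6), regardless of tokens.
    printf("slot head: %lld\n", (long long) find_free_run(cells, 4));
    return 0;
}

If that reading is correct, calling find_slot again for a ubatch whose cells already have positions set would simply claim another free run rather than reusing the original cells, which is consistent with the uncertainty about cache-hit detection noted in the message.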

File tree

1 file changed (+75, -23)


tests/test-memory.cpp

Lines changed: 75 additions & 23 deletions
@@ -20,6 +20,7 @@
 #include "../src/llama-model.h"
 
 #include "common.h"
+#include "ggml.h"
 #include "llama.h"
 
 #include <algorithm>
@@ -59,6 +60,43 @@ static std::shared_ptr<llama_model> _make_model(
     return model;
 }
 
+static llama_batch _make_batch(
+        std::vector<std::vector<llama_token>>  token_seqs,
+        std::vector<std::vector<llama_seq_id>> seq_ids) {
+    GGML_ASSERT(token_seqs.size() == seq_ids.size());
+
+    size_t total_tokens = 0;
+    for (const auto & token_seq : token_seqs) {
+        total_tokens += token_seq.size();
+    }
+    size_t max_seq_ids = 0;
+    for (const auto & seq_ids_i : seq_ids) {
+        max_seq_ids = std::max(max_seq_ids, seq_ids_i.size());
+    }
+    llama_batch batch = llama_batch_init(total_tokens, 0, max_seq_ids);
+
+    for (size_t i = 0; i < token_seqs.size(); ++i) {
+        const auto & token_seq = token_seqs[i];
+        const auto & seq_ids_i = seq_ids[i];
+        for (int pos = 0; pos < (int) token_seq.size(); ++pos) {
+            common_batch_add(batch, token_seq[pos], pos, seq_ids_i, false);
+        }
+    }
+    return batch;
+}
+
+static bool is_source_tensor(ggml_tensor * child, ggml_tensor * parent) {
+    if (!child || !parent) return false;
+    for (size_t i = 0; i < GGML_MAX_SRC; ++i) {
+        if (child->src[i] == parent) {
+            return true;
+        } else if (child->src[i] != nullptr && is_source_tensor(child->src[i], parent)) {
+            return true;
+        }
+    }
+    return false;
+}
+
 struct log_scope {
     const char * name;
     explicit log_scope(const char * name) : name(name) {
@@ -121,33 +159,47 @@ static void test_llama_kv_cache_unified_single_seq() {
     );
 
     // Create the micro batch with a single 3-token sequence
-    //
-    // NOTE: A bunch of these asserts were just me figuring out how the batches
-    // relate to each other, but they're left for future readers to help in the
-    // same understanding process.
-    llama_seq_id seq_id = 42;
-    llama_batch batch = llama_batch_init(3, 0, 1);
-    common_batch_add(batch, 101, 0, {seq_id}, false);
-    common_batch_add(batch, 1, 1, {seq_id}, false);
-    common_batch_add(batch, 102, 2, {seq_id}, false);
-    llama_sbatch sbatch(batch, 0, true, false);
-    GGML_ASSERT(batch.n_tokens == 3);
-    GGML_ASSERT(sbatch.n_tokens == 3);
-    GGML_ASSERT(!sbatch.seq.empty());
-    llama_ubatch ubatch = sbatch.split_simple(4);
-    printf("ubatch.n_seqs=%d\n", ubatch.n_seqs);
-    GGML_ASSERT(ubatch.n_seqs == 3);
-    GGML_ASSERT(ubatch.n_seq_tokens == 1);
-    GGML_ASSERT(ubatch.n_tokens == 3);
-    GGML_ASSERT(ubatch.seq_id[0][0] == seq_id);
-    GGML_ASSERT(ubatch.seq_id[1][0] == seq_id);
-    GGML_ASSERT(ubatch.seq_id[2][0] == seq_id);
+    llama_batch batch1 = _make_batch({{101, 1, 102}}, {{42}});
+    llama_sbatch sbatch1 = cache.sbatch_init(batch1, false);
+    llama_ubatch ubatch1 = cache.ubatch_next(sbatch1, 4, false);
 
     // Find a slot for a new sequence
-    GGML_ASSERT(cache.find_slot(ubatch));
+    GGML_ASSERT(cache.find_slot(ubatch1));
+
+    // Cache the k/v for a single layer in this slot
+    ggml_context * ctx = ggml_init({10240, NULL, false});
+    ggml_tensor * k1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0));
+    ggml_tensor * v1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0));
+    ggml_tensor * k1_view = cache.cpy_k(ctx, k1, 0);
+    ggml_tensor * v1_view = cache.cpy_v(ctx, v1, 0);
+    GGML_ASSERT(is_source_tensor(k1_view, k1));
+    GGML_ASSERT(is_source_tensor(v1_view, v1));
+
+    // Create a second batch with different tokens and find a slot for it
+    llama_batch batch2 = _make_batch({{1, 2, 3, 4}}, {{5}});
+    llama_sbatch sbatch2 = cache.sbatch_init(batch2, false);
+    llama_ubatch ubatch2 = cache.ubatch_next(sbatch2, 4, false);
+    GGML_ASSERT(cache.find_slot(ubatch2));
+
+    // Add some different tensors
+    ggml_tensor * k2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0));
+    ggml_tensor * v2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0));
+    ggml_tensor * k2_view = cache.cpy_k(ctx, k2, 0);
+    ggml_tensor * v2_view = cache.cpy_v(ctx, v2, 0);
+    GGML_ASSERT(is_source_tensor(k2_view, k2));
+    GGML_ASSERT(is_source_tensor(v2_view, v2));
+
+    // Make sure the second batch's views are not sourced from the first batch's k/v
+    GGML_ASSERT(!is_source_tensor(k2_view, k1));
+    GGML_ASSERT(!is_source_tensor(v2_view, v1));
+
+    // Re-find the slot for the first batch and make sure it is a cache hit
+    GGML_ASSERT(cache.find_slot(ubatch1));
 
     // Clean up
-    llama_batch_free(batch);
+    llama_batch_free(batch1);
+    llama_batch_free(batch2);
+    ggml_free(ctx);
 }
 
 /*- Recurrent Cache ----------------------------------------------------------*/
