Commit ec08571

tests: Add a test for finding a slot for a single sequence

Branch: HybridCache
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent: 5c42d3a

File tree: 1 file changed (+54, -0)


tests/test-memory.cpp

Lines changed: 54 additions & 0 deletions
@@ -13,11 +13,13 @@
  *----------------------------------------------------------------------------*/

 #include "../src/llama-arch.h"
+#include "../src/llama-batch.h"
 #include "../src/llama-hparams.h"
 #include "../src/llama-impl.h"
 #include "../src/llama-kv-cache.h"
 #include "../src/llama-model.h"

+#include "common.h"
 #include "llama.h"

 #include <algorithm>
@@ -103,6 +105,58 @@ static void test_llama_kv_cache_unified_constructor() {
     );
 }

+/* Test that the unified cache can operate with a single seq */
+static void test_llama_kv_cache_unified_single_seq() {
+    auto model = _make_model();
+    llama_kv_cache_unified cache(
+        /* model   */ *model,
+        /* type_k  */ GGML_TYPE_F16,
+        /* type_v  */ GGML_TYPE_F16,
+        /* v_trans */ false,
+        /* offload */ false,
+        /* kv_size */ 10,
+        /* padding */ 10
+    );
+    GGML_ASSERT(cache.get_used_cells() == 0);
+
+    // Create the micro batch with a single 3-token sequence
+    //
+    // NOTE: A bunch of these asserts were just me figuring out how the batches
+    //  relate to each other, but they're left for future readers to help in the
+    //  same understanding process.
+    llama_seq_id seq_id = 42;
+    llama_batch batch = llama_batch_init(3, 0, 1);
+    common_batch_add(batch, 101, 0, {seq_id}, false);
+    common_batch_add(batch,   1, 1, {seq_id}, false);
+    common_batch_add(batch, 102, 2, {seq_id}, false);
+    llama_sbatch sbatch(batch, 0, true, false);
+    GGML_ASSERT(batch.n_tokens == 3);
+    GGML_ASSERT(sbatch.n_tokens == 3);
+    GGML_ASSERT(!sbatch.seq.empty());
+    llama_ubatch ubatch = sbatch.split_simple(4);
+    printf("ubatch.n_seqs=%d\n", ubatch.n_seqs);
+    GGML_ASSERT(ubatch.n_seqs == 3);
+    GGML_ASSERT(ubatch.n_seq_tokens == 1);
+    GGML_ASSERT(ubatch.n_tokens == 3);
+    GGML_ASSERT(ubatch.seq_id[0][0] == seq_id);
+    GGML_ASSERT(ubatch.seq_id[1][0] == seq_id);
+    GGML_ASSERT(ubatch.seq_id[2][0] == seq_id);
+
+    // Find a slot for a new sequence
+    GGML_ASSERT(cache.find_slot(ubatch));
+    printf("cache.head=%d\n", cache.head);
+    GGML_ASSERT(cache.head == 0); // Ready to start filling at the beginning
+    GGML_ASSERT(cache.used == 3);
+    for (int i = 0; i < 3; ++i) {
+        GGML_ASSERT(cache.cells[i].seq_id.size() == 1);
+        GGML_ASSERT(*cache.cells[i].seq_id.begin() == seq_id);
+        GGML_ASSERT(cache.cells[i].pos == i);
+    }
+
+    // Clean up
+    llama_batch_free(batch);
+}
+
 /*- Recurrent Cache ----------------------------------------------------------*/

 /* Test that the recurrent cache can be constructed and destructed safely */
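
For readers following the NOTE in the test body: split_simple() does not group tokens by sequence, so each of the 3 tokens lands in its own sequence slot, which is why the asserts see n_seq_tokens == 1 and n_seqs == n_tokens == 3 while every slot still reports the one real seq_id. Below is a minimal standalone sketch of that batch -> sbatch -> ubatch pipeline, assuming it is built against the same internal headers this test already includes (the token ids and sizes are arbitrary):

// sketch.cpp - a hypothetical distillation of the pipeline exercised above;
// not part of this commit
#include "../src/llama-batch.h"
#include "common.h"
#include "llama.h"

#include <cstdio>

int main() {
    llama_seq_id seq_id = 0;                         // any single sequence id
    llama_batch  batch  = llama_batch_init(2, 0, 1); // 2 tokens, no embeddings, 1 seq max
    common_batch_add(batch, 11, 0, {seq_id}, false); // token 11 at pos 0
    common_batch_add(batch, 12, 1, {seq_id}, false); // token 12 at pos 1

    llama_sbatch sbatch(batch, 0, true, false);
    llama_ubatch ubatch = sbatch.split_simple(4);    // up to 4 tokens per ubatch

    // expected output: n_tokens=2 n_seqs=2 n_seq_tokens=1
    printf("n_tokens=%d n_seqs=%d n_seq_tokens=%d\n",
           (int) ubatch.n_tokens, (int) ubatch.n_seqs, (int) ubatch.n_seq_tokens);

    llama_batch_free(batch);
    return 0;
}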

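One thing the diff does not show is a call site for the new test function, so presumably it is wired into the test binary's entry point alongside the existing unified-cache constructor test. A hypothetical sketch of what that registration could look like (not part of this commit):

int main() {
    test_llama_kv_cache_unified_constructor();
    test_llama_kv_cache_unified_single_seq(); // added by this commit
    // ... recurrent cache tests ...
    return 0;
}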