 *----------------------------------------------------------------------------*/

 #include "../src/llama-arch.h"
+#include "../src/llama-batch.h"
 #include "../src/llama-hparams.h"
 #include "../src/llama-impl.h"
 #include "../src/llama-kv-cache.h"
 #include "../src/llama-model.h"

+#include "common.h"
 #include "llama.h"

 #include <algorithm>
@@ -103,6 +105,58 @@ static void test_llama_kv_cache_unified_constructor() {
     );
 }

+/* Test that the unified cache can operate with a single seq */
+static void test_llama_kv_cache_unified_single_seq() {
+    auto model = _make_model();
+    llama_kv_cache_unified cache(
+        /* model   */ *model,
+        /* type_k  */ GGML_TYPE_F16,
+        /* type_v  */ GGML_TYPE_F16,
+        /* v_trans */ false,
+        /* offload */ false,
+        /* kv_size */ 10,
+        /* padding */ 10
+    );
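+    // A freshly constructed cache should report zero used cells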
+    GGML_ASSERT(cache.get_used_cells() == 0);
+
+    // Create the micro batch with a single 3-token sequence
+    //
+    // NOTE: A bunch of these asserts were just me figuring out how the batches
+    // relate to each other, but they're left for future readers to help in the
+    // same understanding process.
+    llama_seq_id seq_id = 42;
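+    // llama_batch_init(n_tokens, embd, n_seq_max): room for 3 tokens, token
+    // ids rather than embeddings (embd = 0), at most 1 seq id per token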
+    llama_batch batch = llama_batch_init(3, 0, 1);
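+    // common_batch_add(batch, token, pos, seq_ids, logits)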
+    common_batch_add(batch, 101, 0, {seq_id}, false);
+    common_batch_add(batch,   1, 1, {seq_id}, false);
+    common_batch_add(batch, 102, 2, {seq_id}, false);
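+    // llama_sbatch(batch, n_embd = 0, simple_split = true, logits_all = false)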
+    llama_sbatch sbatch(batch, 0, true, false);
+    GGML_ASSERT(batch.n_tokens == 3);
+    GGML_ASSERT(sbatch.n_tokens == 3);
+    GGML_ASSERT(!sbatch.seq.empty());
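+    // A simple split places each token in its own single-token "sequence"
+    // within the ubatch, so n_seqs == n_tokens even though every token
+    // carries the same seq_id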
+    llama_ubatch ubatch = sbatch.split_simple(4);
+    printf("ubatch.n_seqs=%u\n", ubatch.n_seqs);
+    GGML_ASSERT(ubatch.n_seqs == 3);
+    GGML_ASSERT(ubatch.n_seq_tokens == 1);
+    GGML_ASSERT(ubatch.n_tokens == 3);
+    GGML_ASSERT(ubatch.seq_id[0][0] == seq_id);
+    GGML_ASSERT(ubatch.seq_id[1][0] == seq_id);
+    GGML_ASSERT(ubatch.seq_id[2][0] == seq_id);
+
+    // Find a slot for a new sequence
+    GGML_ASSERT(cache.find_slot(ubatch));
+    printf("cache.head=%u\n", cache.head);
+    GGML_ASSERT(cache.head == 0); // Ready to start filling at the beginning
+    GGML_ASSERT(cache.used == 3);
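+    // The three tokens should land in three consecutive cells, each tagged
+    // with the single seq_id from the batch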
+    for (int i = 0; i < 3; ++i) {
+        GGML_ASSERT(cache.cells[i].seq_id.size() == 1);
+        GGML_ASSERT(*cache.cells[i].seq_id.begin() == seq_id);
+        GGML_ASSERT(cache.cells[i].pos == i);
+    }
+
+    // Clean up
+    llama_batch_free(batch);
+}
+
 /*- Recurrent Cache ----------------------------------------------------------*/

 /* Test that the recurrent cache can be constructed and destructed safely */