
Commit f46a727

tests: Add initial unit tests for kv caches
So far this only tests constructor logic (and barely that)

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 0cc96c9 commit f46a727

2 files changed (+134, -0)

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ endif()
 llama_build_and_test(test-log.cpp)
 llama_build_and_test(test-chat-template.cpp)
 llama_build_and_test(test-regex-partial.cpp)
+llama_build_and_test(test-memory.cpp)
 
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
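
The llama_build_and_test helper used here is the test suite's existing wrapper for compiling a test source and registering the resulting binary with CTest, so test-memory.cpp needs no wiring beyond this single line.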

tests/test-memory.cpp

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
#include "../src/llama-arch.h"
#include "../src/llama-hparams.h"
#include "../src/llama-impl.h"
#include "../src/llama-kv-cache.h"
#include "../src/llama-model.h"

#include "llama.h"

#include <algorithm>
#include <cstdio>
#include <memory>
#include <vector>

/*- Helpers ------------------------------------------------------------------*/

static std::shared_ptr<llama_model> _make_model() {
    // start from the library defaults so no field is left uninitialized
    llama_model_params params = llama_model_default_params();
    params.tensor_buft_overrides = nullptr;
    std::shared_ptr<llama_model> model(new llama_model(params));
    model->hparams = llama_hparams();
    model->arch = LLM_ARCH_LLAMA;
    return model;
}

struct log_scope {
    const char * name;
    explicit log_scope(const char * name) : name(name) {
        LLAMA_LOG_INFO("--------\n");
        LLAMA_LOG_INFO("START: %s\n", name);
    }
    ~log_scope() {
        LLAMA_LOG_INFO("END: %s\n", name);
        LLAMA_LOG_INFO("--------\n");
    }
};

#define LOG_SCOPE() log_scope __log_scope(__func__)

/*- Unified Cache ------------------------------------------------------------*/

/* Test that the unified cache can be constructed and destructed safely */
static void test_llama_kv_cache_unified_constructor() {
    LOG_SCOPE();
    auto model = _make_model();
    llama_kv_cache_unified cache(
        /* model   */ *model,
        /* type_k  */ GGML_TYPE_F32,
        /* type_v  */ GGML_TYPE_F16,
        /* v_trans */ false,
        /* offload */ false,
        /* kv_size */ 10,
        /* padding */ 10
    );
}
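
/* Illustrative sketch, not part of this commit: a natural follow-up would
 * exercise the cache past construction. This assumes clear() and
 * seq_pos_max() from the mainline llama_kv_cache interface; the names may
 * differ on this branch. */
static void test_llama_kv_cache_unified_clear() {
    LOG_SCOPE();
    auto model = _make_model();
    llama_kv_cache_unified cache(
        /* model   */ *model,
        /* type_k  */ GGML_TYPE_F32,
        /* type_v  */ GGML_TYPE_F16,
        /* v_trans */ false,
        /* offload */ false,
        /* kv_size */ 10,
        /* padding */ 10
    );
    cache.clear();
    /* an empty cache should report no populated positions for any sequence */
    GGML_ASSERT(cache.seq_pos_max(0) == -1);
}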

/*- Recurrent Cache ----------------------------------------------------------*/

/* Test that the recurrent cache can be constructed and destructed safely */
static void test_llama_kv_cache_recurrent_constructor() {
    LOG_SCOPE();
    auto model = _make_model();
    llama_kv_cache_recurrent cache(
        /* model   */ *model,
        /* type_k  */ GGML_TYPE_F32,
        /* type_v  */ GGML_TYPE_F16,
        /* offload */ false,
        /* kv_size */ 10
    );
}

/*- Hybrid Cache -------------------------------------------------------------*/

/* Test that the hybrid cache can be constructed and destructed safely */
static void test_llama_kv_cache_hybrid_constructor() {
    LOG_SCOPE();
    auto model = _make_model();
    model->hparams.n_layer = 4;
    model->hparams.n_embd_head_k = 4;
    model->hparams.n_embd_head_v = 4;
    auto & recurrent_layer_arr = model->hparams.recurrent_layer_arr;
    recurrent_layer_arr[0] = 1;
    recurrent_layer_arr[1] = 0;
    recurrent_layer_arr[2] = 1;
    recurrent_layer_arr[3] = 0;
    auto & n_head_kv_arr = model->hparams.n_head_kv_arr;
    n_head_kv_arr[0] = 16;
    n_head_kv_arr[1] = 8;
    n_head_kv_arr[2] = 16;
    n_head_kv_arr[3] = 8;

    std::unique_ptr<llama_kv_cache_unified> u_cache(
        new llama_kv_cache_unified(
            /* model   */ *model,
            /* type_k  */ GGML_TYPE_F32,
            /* type_v  */ GGML_TYPE_F16,
            /* v_trans */ false,
            /* offload */ false,
            /* kv_size */ 20,
            /* padding */ 2
        )
    );
    auto * u_cache_ptr = u_cache.get();
    std::unique_ptr<llama_kv_cache_recurrent> r_cache(
        new llama_kv_cache_recurrent(
            /* model   */ *model,
            /* type_k  */ GGML_TYPE_F32,
            /* type_v  */ GGML_TYPE_F16,
            /* offload */ false,
            /* kv_size */ 10
        )
    );
    auto * r_cache_ptr = r_cache.get();

    std::vector<llama_kv_cache_hybrid::child_cache> children;
    children.emplace_back(std::move(u_cache), std::vector<size_t>{1, 3});
    children.emplace_back(std::move(r_cache), std::vector<size_t>{0, 2});

    llama_kv_cache_hybrid cache(model->hparams, std::move(children));

    GGML_ASSERT(cache.get_child_cache<llama_kv_cache_unified>() == u_cache_ptr);
    GGML_ASSERT(cache.get_child_cache<llama_kv_cache_recurrent>() == r_cache_ptr);
}
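
/* Note on the wiring above: each child cache owns a disjoint set of layer
 * indices -- the recurrent child gets layers 0 and 2 (the layers flagged in
 * recurrent_layer_arr) and the unified child gets attention layers 1 and 3.
 * get_child_cache<T>() looks a child up by its concrete type, which presumes
 * at most one child of each cache type per hybrid cache. */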

/*- Main ---------------------------------------------------------------------*/

int main() {
    // Unified Cache Tests
    test_llama_kv_cache_unified_constructor();
    // Recurrent Cache Tests
    test_llama_kv_cache_recurrent_constructor();
    // Hybrid Cache Tests
    test_llama_kv_cache_hybrid_constructor();
    return 0;
}
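
As the commit message says, these tests stop at constructor logic: construction, destruction, and the hybrid cache's child lookup. Later work on the HybridCache branch can extend the same helpers (_make_model, LOG_SCOPE) to cover actual cache state operations. Once built, the binary should run with the rest of the suite via ctest, likely under a name like test-memory, though the exact name depends on how llama_build_and_test derives it.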
