Commit 60ed266

tests: Use real params for allocation in all tests

Branch: HybridCache
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent f46a727

File tree: 1 file changed (+44 lines, -7 lines)

tests/test-memory.cpp

Lines changed: 44 additions & 7 deletions
@@ -12,12 +12,40 @@
 
 /*- Helpers ------------------------------------------------------------------*/
 
-static std::shared_ptr<llama_model> _make_model() {
+static std::shared_ptr<llama_model> _make_model(
+        llm_arch arch = LLM_ARCH_LLAMA,
+        uint32_t n_layer = 4,
+        uint32_t n_embd_head_k = 4,
+        uint32_t n_embd_head_v = 4,
+        uint32_t n_head = 8,
+        uint32_t n_head_kv = 2) {
+
     llama_model_params params;
     params.tensor_buft_overrides = nullptr;
     std::shared_ptr<llama_model> model(new llama_model(params));
     model->hparams = llama_hparams();
-    model->arch = LLM_ARCH_LLAMA;
+    model->arch = arch;
+
+    model->hparams.n_layer = n_layer;
+    model->hparams.n_embd_head_k = n_embd_head_k;
+    model->hparams.n_embd_head_v = n_embd_head_v;
+
+    auto& recurrent_layer_arr = model->hparams.recurrent_layer_arr;
+    std::fill(
+        recurrent_layer_arr.begin(),
+        recurrent_layer_arr.end(),
+        llm_arch_is_recurrent(arch));
+
+    // If set to 0, assume the test will fill out the array elementwise (hybrid)
+    if (n_head > 0) {
+        auto& n_head_arr = model->hparams.n_head_arr;
+        std::fill(n_head_arr.begin(), n_head_arr.end(), n_head);
+    }
+    if (n_head_kv > 0) {
+        auto& n_head_kv_arr = model->hparams.n_head_kv_arr;
+        std::fill(n_head_kv_arr.begin(), n_head_kv_arr.end(), n_head_kv);
+    }
+
     return model;
 }
 
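The reworked helper follows a common test-factory pattern: defaulted parameters cover the usual case, and passing 0 for n_head or n_head_kv means the caller fills the per-layer arrays itself. A minimal standalone sketch of that pattern, using a hypothetical mock_hparams struct and make_hparams helper rather than the real llama.cpp types:

// Standalone sketch of the defaulted-factory pattern used by _make_model().
// mock_hparams and make_hparams are hypothetical stand-ins, not llama.cpp types.
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

struct mock_hparams {
    static constexpr std::size_t MAX_LAYERS = 8;
    uint32_t n_layer = 0;
    std::array<uint32_t, MAX_LAYERS> n_head_arr{};
    std::array<uint32_t, MAX_LAYERS> n_head_kv_arr{};
    std::array<bool,     MAX_LAYERS> recurrent_layer_arr{};
};

static mock_hparams make_hparams(
        uint32_t n_layer   = 4,
        uint32_t n_head    = 8,    // 0 => caller fills n_head_arr per layer
        uint32_t n_head_kv = 2,    // 0 => caller fills n_head_kv_arr per layer
        bool     recurrent = false) {
    mock_hparams hp;
    hp.n_layer = n_layer;
    std::fill(hp.recurrent_layer_arr.begin(), hp.recurrent_layer_arr.end(), recurrent);
    if (n_head > 0) {
        std::fill(hp.n_head_arr.begin(), hp.n_head_arr.end(), n_head);
    }
    if (n_head_kv > 0) {
        std::fill(hp.n_head_kv_arr.begin(), hp.n_head_kv_arr.end(), n_head_kv);
    }
    return hp;
}

int main() {
    // Defaults: every layer gets the same head counts.
    auto uniform = make_hparams();
    assert(uniform.n_head_arr[0] == 8 && uniform.n_head_kv_arr[3] == 2);

    // Zeros: the arrays stay zero-initialized so a hybrid test can set them
    // layer by layer afterwards.
    auto hybrid = make_hparams(/* n_layer */ 4, /* n_head */ 0, /* n_head_kv */ 0);
    assert(hybrid.n_head_arr[0] == 0 && hybrid.n_head_kv_arr[0] == 0);
    return 0;
}

The same shape shows up in the hybrid test in the last hunk below, which passes zeros and then assigns each layer explicitly.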

@@ -57,7 +85,7 @@ static void test_llama_kv_cache_unified_constructor() {
 /* Test that the recurrent cache can be constructed and destructed safely */
 static void test_llama_kv_cache_recurrent_constructor() {
     LOG_SCOPE();
-    auto model = _make_model();
+    auto model = _make_model(LLM_ARCH_MAMBA);
     llama_kv_cache_recurrent cache(
         /* model */ *model,
         /* type_k */ GGML_TYPE_F32,
@@ -72,15 +100,24 @@ static void test_llama_kv_cache_recurrent_constructor() {
 /* Test that the hybrid cache can be constructed and destructed safely */
 static void test_llama_kv_cache_hybrid_constructor() {
     LOG_SCOPE();
-    auto model = _make_model();
-    model->hparams.n_layer = 4;
-    model->hparams.n_embd_head_k = 4;
-    model->hparams.n_embd_head_v = 4;
+    auto model = _make_model(
+        /* arch =*/ LLM_ARCH_LLAMA,
+        /* n_layer =*/ 4,
+        /* n_embd_head_k =*/ 4,
+        /* n_embd_head_v =*/ 4,
+        /* n_head =*/ 0,
+        /* n_head_kv =*/ 0
+    );
     auto& recurrent_layer_arr = model->hparams.recurrent_layer_arr;
     recurrent_layer_arr[0] = 1;
     recurrent_layer_arr[1] = 0;
     recurrent_layer_arr[2] = 1;
     recurrent_layer_arr[3] = 0;
+    auto& n_head_arr = model->hparams.n_head_arr;
+    n_head_arr[0] = 16;
+    n_head_arr[1] = 32;
+    n_head_arr[2] = 16;
+    n_head_arr[3] = 32;
     auto& n_head_kv_arr = model->hparams.n_head_kv_arr;
     n_head_kv_arr[0] = 16;
     n_head_kv_arr[1] = 8;
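As a sanity check on the interleaved layout the hybrid test builds, here is a small self-contained sketch (hypothetical, values copied from the hunk above; not part of the commit):

// Hypothetical standalone check of the interleaved hybrid layout built by
// test_llama_kv_cache_hybrid_constructor() above (values copied from the hunk).
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

int main() {
    constexpr std::size_t n_layer = 4;

    // Layers 0 and 2 are recurrent, layers 1 and 3 use attention, as in the test.
    std::array<bool,     n_layer> recurrent_layer_arr = { true, false, true, false };
    std::array<uint32_t, n_layer> n_head_arr          = { 16, 32, 16, 32 };

    for (std::size_t il = 0; il < n_layer; ++il) {
        // Every layer carries a real head count, so allocation sizes are non-zero.
        assert(n_head_arr[il] > 0);
        // Recurrent and attention layers alternate.
        if (il > 0) {
            assert(recurrent_layer_arr[il] != recurrent_layer_arr[il - 1]);
        }
    }
    return 0;
}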

0 commit comments
