/*- Helpers ------------------------------------------------------------------*/

-static std::shared_ptr<llama_model> _make_model() {
+static std::shared_ptr<llama_model> _make_model(
+        llm_arch arch = LLM_ARCH_LLAMA,
+        uint32_t n_layer = 4,
+        uint32_t n_embd_head_k = 4,
+        uint32_t n_embd_head_v = 4,
+        uint32_t n_head = 8,
+        uint32_t n_head_kv = 2) {
+
    llama_model_params params;
    params.tensor_buft_overrides = nullptr;
    std::shared_ptr<llama_model> model(new llama_model(params));
    model->hparams = llama_hparams();
-    model->arch = LLM_ARCH_LLAMA;
+    model->arch = arch;
+
+    model->hparams.n_layer = n_layer;
+    model->hparams.n_embd_head_k = n_embd_head_k;
+    model->hparams.n_embd_head_v = n_embd_head_v;
+
+    auto & recurrent_layer_arr = model->hparams.recurrent_layer_arr;
+    std::fill(
+        recurrent_layer_arr.begin(),
+        recurrent_layer_arr.end(),
+        llm_arch_is_recurrent(arch));
+
+    // If set to 0, assume the test will fill out the array elementwise (hybrid)
+    if (n_head > 0) {
+        auto & n_head_arr = model->hparams.n_head_arr;
+        std::fill(n_head_arr.begin(), n_head_arr.end(), n_head);
+    }
+    if (n_head_kv > 0) {
+        auto & n_head_kv_arr = model->hparams.n_head_kv_arr;
+        std::fill(n_head_kv_arr.begin(), n_head_kv_arr.end(), n_head_kv);
+    }
+
    return model;
}
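The reworked helper fills each per-layer hparams array with one uniform value, and passing 0 for `n_head` or `n_head_kv` signals that the test will populate the array elementwise instead (the hybrid case later in this diff). A minimal, self-contained sketch of that fill-then-override pattern, using a plain `std::array` in place of the `llama_hparams` arrays, purely for illustration and not part of the diff:

```cpp
// Illustration only: a stand-in for llama_hparams' fixed-size per-layer arrays.
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    std::array<uint32_t, 4> n_head_arr{};   // hypothetical 4-layer model

    // Uniform case (n_head > 0 in the helper): every layer gets the same head count.
    std::fill(n_head_arr.begin(), n_head_arr.end(), uint32_t(8));

    // Hybrid case (n_head == 0 in the helper): the test overrides individual layers itself.
    n_head_arr[1] = 32;
    n_head_arr[3] = 32;

    for (std::size_t il = 0; il < n_head_arr.size(); ++il) {
        std::printf("layer %zu: n_head = %u\n", il, n_head_arr[il]);
    }
    return 0;
}
```

This is the same convention that lets test_llama_kv_cache_hybrid_constructor below pass 0 for both head counts and then assign per-layer values by hand.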
@@ -57,7 +85,7 @@ static void test_llama_kv_cache_unified_constructor() {
/* Test that the recurrent cache can be constructed and destructed safely */
static void test_llama_kv_cache_recurrent_constructor() {
    LOG_SCOPE();
-    auto model = _make_model();
+    auto model = _make_model(LLM_ARCH_MAMBA);
    llama_kv_cache_recurrent cache(
        /* model */ *model,
        /* type_k */ GGML_TYPE_F32,
@@ -72,15 +100,24 @@ static void test_llama_kv_cache_recurrent_constructor() {
/* Test that the hybrid cache can be constructed and destructed safely */
static void test_llama_kv_cache_hybrid_constructor() {
    LOG_SCOPE();
-    auto model = _make_model();
-    model->hparams.n_layer = 4;
-    model->hparams.n_embd_head_k = 4;
-    model->hparams.n_embd_head_v = 4;
+    auto model = _make_model(
+        /* arch =*/ LLM_ARCH_LLAMA,
+        /* n_layer =*/ 4,
+        /* n_embd_head_k =*/ 4,
+        /* n_embd_head_v =*/ 4,
+        /* n_head =*/ 0,
+        /* n_head_kv =*/ 0
+    );
    auto & recurrent_layer_arr = model->hparams.recurrent_layer_arr;
    recurrent_layer_arr[0] = 1;
    recurrent_layer_arr[1] = 0;
    recurrent_layer_arr[2] = 1;
    recurrent_layer_arr[3] = 0;
+    auto & n_head_arr = model->hparams.n_head_arr;
+    n_head_arr[0] = 16;
+    n_head_arr[1] = 32;
+    n_head_arr[2] = 16;
+    n_head_arr[3] = 32;
    auto & n_head_kv_arr = model->hparams.n_head_kv_arr;
    n_head_kv_arr[0] = 16;
    n_head_kv_arr[1] = 8;