@@ -13196,6 +13196,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;

     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13204,58 +13206,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                    *this,
-                    nullptr,
-                    GGML_TYPE_F32,
-                    GGML_TYPE_F32,
-                    cparams.offload_kqv,
-                    std::max((uint32_t) 1, cparams.n_seq_max),
-                    cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                        *this,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
-                        cparams.offload_kqv,
-                        params.swa_full,
-                        cparams.n_ctx,
-                        cparams.n_seq_max,
-                        cparams.n_batch,
-                        padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                         *this,
                         nullptr,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
+                        GGML_TYPE_F32,
+                        GGML_TYPE_F32,
                         cparams.offload_kqv,
-                        cparams.n_ctx,
-                        cparams.n_seq_max,
-                        padding,
-                        hparams.n_swa,
-                        hparams.swa_type);
+                        std::max((uint32_t) 1, cparams.n_seq_max),
+                        cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model             */ *this,
+                        /* attn_type_k       */ params.type_k,
+                        /* attn_type_v       */ params.type_v,
+                        /* attn_v_trans      */ !cparams.flash_attn,
+                        /* attn_kv_size      */ cparams.n_ctx,
+                        /* attn_n_pad        */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
+                        /* recurrent_type_k  */ GGML_TYPE_F32,
+                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max         */ cparams.n_seq_max,
+                        /* offload           */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                            *this,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            params.swa_full,
+                            cparams.n_ctx,
+                            cparams.n_seq_max,
+                            cparams.n_batch,
+                            padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                            *this,
+                            nullptr,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            cparams.n_ctx,
+                            cparams.n_seq_max,
+                            padding,
+                            hparams.n_swa,
+                            hparams.swa_type);
+                    }
                 }
             }
     }
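For context (not part of the diff): the new default branch no longer lists recurrent architectures as explicit case labels and instead dispatches on llm_arch_is_recurrent(arch) and llm_arch_is_hybrid_recurrent(arch). Below is a minimal sketch of what such predicates could look like, assuming they are plain switches over llm_arch; the actual helpers are defined elsewhere in the tree (llama-arch.*), and the architecture list only mirrors the cases removed above, so treat names and placement as illustrative.

#include "llama-arch.h" // llama.cpp internal header providing llm_arch / LLM_ARCH_*

// Sketch only: assumes the helper is a simple switch over llm_arch.
bool llm_arch_is_recurrent(const llm_arch & arch) {
    switch (arch) {
        // the architectures previously special-cased in create_memory()
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}

// Sketch only: hybrid models pair recurrent (SSM/RNN-style) blocks with
// regular attention layers, so create_memory() builds both cache types for
// them via llama_kv_cache_hybrid_recurrent.
bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
    switch (arch) {
        // architectures would be listed here as hybrid models land
        default:
            return false;
    }
}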