@@ -13190,6 +13190,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13198,58 +13200,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_batch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model             */ *this,
+                        /* attn_type_k       */ params.type_k,
+                        /* attn_type_v       */ params.type_v,
+                        /* attn_v_trans      */ !cparams.flash_attn,
+                        /* attn_kv_size      */ cparams.n_ctx,
+                        /* attn_n_pad        */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
+                        /* recurrent_type_k  */ GGML_TYPE_F32,
+                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max         */ cparams.n_seq_max,
+                        /* offload           */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_batch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
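For context, the removed per-architecture cases (MAMBA and the RWKV variants) are now expected to be covered by the `llm_arch_is_recurrent(arch)` check in the default branch. Below is a minimal sketch of such a predicate, assuming it simply mirrors the case labels deleted above; the real helper is defined elsewhere in the tree and may cover additional architectures.

// Hypothetical sketch only: mirrors the case labels removed in this diff.
// Assumes the llm_arch enum from llama-arch.h; the upstream helper may differ.
static bool llm_arch_is_recurrent_sketch(llm_arch arch) {
    switch (arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}

With the dispatch expressed this way, create_memory() no longer carries per-model case lists for cache selection; a new recurrent or hybrid architecture only needs to be recognized by the corresponding predicate.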