@@ -13195,59 +13195,112 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max));
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                if (llm_arch_is_hybrid(arch)) {
+                    // make vectors of recurrent and non-recurrent layer indices
+                    std::vector<size_t> recurrent_layers;
+                    std::vector<size_t> unified_layers;
+                    for (auto il = 0u; il < hparams.n_layer; ++il) {
+                        if (hparams.recurrent_layer(il)) {
+                            recurrent_layers.push_back(il);
+                        } else {
+                            unified_layers.push_back(il);
+                        }
+                    }
 
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    // initialize the children
+                    std::vector<llama_kv_cache_hybrid::child_cache> children;
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_recurrent(
+                                *this,
+                                [&](int32_t il) {
+                                    return hparams.recurrent_layer(il);
+                                },
+                                GGML_TYPE_F32,
+                                GGML_TYPE_F32,
+                                cparams.offload_kqv,
+                                std::max((uint32_t) 1, cparams.n_seq_max))
+                        ),
+                        std::move(recurrent_layers)
+                    );
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_unified(
+                                *this,
+                                [&](int32_t il) {
+                                    return ! hparams.recurrent_layer(il);
+                                },
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type)
+                        ),
+                        std::move(unified_layers)
+                    );
 
-                if (hparams.n_swa > 0) {
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            cparams.n_ctx,
-                            params.swa_full,
-                            cparams.n_seq_max,
-                            cparams.n_batch,
-                            padding);
-                } else {
-                    res = new llama_kv_cache_unified(
+                    // initialize the hybrid cache with both children
+                    res = new llama_kv_cache_hybrid(hparams, std::move(children));
+                } else if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max));
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.n_swa > 0) {
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                params.swa_full,
+                                cparams.n_seq_max,
+                                cparams.n_batch,
+                                padding);
+                    } else {
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
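
For context, the heart of the new hybrid branch is partitioning layer indices by `hparams.recurrent_layer(il)` and handing each child cache only its own layers via a per-layer filter lambda. Below is a minimal standalone sketch of just that partitioning step, assuming a hypothetical `is_recurrent` predicate and layer count in place of the real `hparams`; it is not the llama.cpp implementation itself.

```cpp
// Standalone sketch of the layer-partitioning idea from the diff above.
// `n_layer` and `is_recurrent` are made-up stand-ins for hparams.n_layer
// and hparams.recurrent_layer(il).
#include <cstdio>
#include <vector>

int main() {
    const size_t n_layer = 8;                 // hypothetical layer count
    auto is_recurrent = [](size_t il) {       // hypothetical predicate
        return il % 2 == 0;                   // pretend even layers are recurrent
    };

    std::vector<size_t> recurrent_layers;
    std::vector<size_t> unified_layers;
    for (size_t il = 0; il < n_layer; ++il) {
        // each layer index goes to exactly one child cache's layer list
        (is_recurrent(il) ? recurrent_layers : unified_layers).push_back(il);
    }

    printf("recurrent: %zu layers, unified: %zu layers\n",
           recurrent_layers.size(), unified_layers.size());
    return 0;
}
```

In the diff, the same split is expressed twice: once as the `recurrent_layers`/`unified_layers` index vectors passed into each `child_cache`, and once as the filter lambdas given to the child constructors, which is what lets the recurrent and unified caches coexist over disjoint layer sets.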