@@ -13040,46 +13040,94 @@ struct llm_build_bailingmoe : public llm_graph_context {
    }
};

-llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
+llama_memory_i * llama_model::create_memory(
+        const llama_memory_params & params,
+        llama_cparams & cparams,
+        const llama_hparams & hparams) const {
    llama_memory_i * res;

    switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
        case LLM_ARCH_BERT:
        case LLM_ARCH_JINA_BERT_V2:
        case LLM_ARCH_NOMIC_BERT:
        case LLM_ARCH_NOMIC_BERT_MOE:
            {
                res = nullptr;
            } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max));
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
        default:
            {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                if (llm_arch_is_hybrid(arch)) {
+                    // make vectors of recurrent and non-recurrent layer indices
+                    std::vector<size_t> recurrent_layers;
+                    std::vector<size_t> unified_layers;
+                    for (auto il = 0u; il < hparams.n_layer; ++il) {
+                        if (hparams.recurrent_layer(il)) {
+                            recurrent_layers.push_back(il);
+                        } else {
+                            unified_layers.push_back(il);
+                        }
+                    }
+
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    // initialize the children
+                    std::vector<llama_kv_cache_hybrid::child_cache> children;
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_recurrent(
+                                *this,
+                                GGML_TYPE_F32,
+                                GGML_TYPE_F32,
+                                cparams.offload_kqv,
+                                std::max((uint32_t) 1, cparams.n_seq_max))
+                        ),
+                        std::move(recurrent_layers)
+                    );
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_unified(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                padding)
+                        ),
+                        std::move(unified_layers)
+                    );
+
+                    // initialize the hybrid cache with both children
+                    res = new llama_kv_cache_hybrid(hparams, std::move(children));
+                } else if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
+                        *this,
+                        GGML_TYPE_F32,
+                        GGML_TYPE_F32,
+                        cparams.offload_kqv,
+                        std::max((uint32_t) 1, cparams.n_seq_max));
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);

-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);

-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

-                res = new llama_kv_cache_unified(
-                        *this,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
-                        cparams.offload_kqv,
-                        cparams.n_ctx,
-                        padding);
+                    res = new llama_kv_cache_unified(
+                        *this,
+                        params.type_k,
+                        params.type_v,
+                        !cparams.flash_attn,
+                        cparams.offload_kqv,
+                        cparams.n_ctx,
+                        padding);
+                }
            }
    }
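For readers following the new default branch: it relies on two routing predicates, llm_arch_is_recurrent() and llm_arch_is_hybrid(), whose definitions are not part of this hunk. Below is a minimal sketch of the shape they presumably have, not the PR's actual code. The recurrent list simply mirrors the case labels removed above, so the else-if branch keeps the old behavior for those architectures; the hybrid predicate is left as a placeholder because no hybrid architecture is named in this hunk.

```cpp
// Hypothetical sketch, not taken from this PR: routing predicates assumed by
// the new default branch. llm_arch and the LLM_ARCH_* enumerators come from
// the repo's llama-arch.h.
#include "llama-arch.h"

bool llm_arch_is_recurrent(llm_arch arch) {
    // mirrors the case labels removed by this diff, so the else-if branch
    // keeps selecting llama_kv_cache_recurrent for these architectures
    switch (arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}

bool llm_arch_is_hybrid(llm_arch arch) {
    // placeholder: architectures that mix attention and recurrent layers would
    // be listed here; none are named in this hunk
    (void) arch;
    return false;
}
```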
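The hunk also uses llama_kv_cache_hybrid::child_cache and a llama_kv_cache_hybrid constructor that are declared elsewhere in the PR. The sketch below shows only the minimal interface those call sites imply; the member names are invented for illustration, and only the constructor argument shapes are taken from the emplace_back and new expressions above.

```cpp
// Hypothetical sketch, not the PR's declaration: the minimal interface implied
// by the call sites above. Member and parameter names here are assumptions.
#include <cstddef>
#include <memory>
#include <vector>

#include "llama-kv-cache.h"  // llama_kv_cache, llama_kv_cache_recurrent, llama_kv_cache_unified
#include "llama-hparams.h"   // llama_hparams

// presumably a llama_kv_cache itself, since the result is stored in llama_memory_i * res
struct llama_kv_cache_hybrid : public llama_kv_cache {
    struct child_cache {
        std::unique_ptr<llama_kv_cache> child;     // owning pointer to the wrapped cache
        std::vector<size_t>             layer_ids; // layer indices routed to this child

        // needed so children.emplace_back(ptr, layer_vector) compiles as written above
        child_cache(std::unique_ptr<llama_kv_cache> child, std::vector<size_t> layer_ids)
            : child(std::move(child)), layer_ids(std::move(layer_ids)) {}
    };

    // matches the call: new llama_kv_cache_hybrid(hparams, std::move(children))
    llama_kv_cache_hybrid(const llama_hparams & hparams, std::vector<child_cache> children);
};
```

The point of the pairing is that each child cache only ever sees the layer indices it owns, which is exactly what the recurrent_layers / unified_layers split in the diff sets up.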