Commit cbf6b10
feat!: Instantiate hybrid cache for hybrid models
There is a small breaking change here that extends the create_memory method signature to include the hparams. Currently, this member is only used inside llama_context and is not part of an interface that's expected to be extended by classes derived from llama_model, so I don't think this should actually break any downstream use cases.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent bb7d4bd commit cbf6b10
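Why create_memory now needs the hparams: hybrid architectures interleave recurrent (SSM-style) layers with standard attention layers, and that split is recorded per layer in the hyper-parameters. Below is a minimal, self-contained sketch of the bucketing step, using illustrative stand-in names (hparams_view, partition_layers) rather than the real llama.cpp types; the actual logic appears in the llama-model.cpp diff further down.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Illustrative stand-in for the two pieces of hyper-parameter state the
    // new code reads; the real llama_hparams carries far more than this.
    struct hparams_view {
        uint32_t n_layer;
        std::vector<bool> recurrent; // recurrent[il] plays the role of hparams.recurrent_layer(il)
    };

    // Mirrors the loop added in llama_model::create_memory: every layer index
    // is routed to exactly one child cache, recurrent or unified.
    static void partition_layers(const hparams_view & hp,
                                 std::vector<size_t> & recurrent_layers,
                                 std::vector<size_t> & unified_layers) {
        for (uint32_t il = 0; il < hp.n_layer; ++il) {
            (hp.recurrent[il] ? recurrent_layers : unified_layers).push_back(il);
        }
    }

    int main() {
        // e.g. a 6-layer hybrid model with attention on every third layer
        const hparams_view hp = { 6, { true, true, false, true, true, false } };
        std::vector<size_t> recurrent_layers, unified_layers;
        partition_layers(hp, recurrent_layers, unified_layers);
        printf("recurrent: %zu layers, unified: %zu layers\n",
               recurrent_layers.size(), unified_layers.size());
        return 0;
    }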

File tree

3 files changed: +78 -27 lines changed

  src/llama-context.cpp
  src/llama-model.cpp
  src/llama-model.h


src/llama-context.cpp

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ llama_context::llama_context(
             /*.type_v =*/ params.type_v,
         };

-        memory.reset(model.create_memory(params_mem, cparams));
+        memory.reset(model.create_memory(params_mem, cparams, hparams));
     }

     // init backends
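At this call site, hparams is the model's hyper-parameter block. The surrounding context is not part of this diff; in upstream llama.cpp the constructor aliases it from the model, roughly as sketched here (assumed context, not shown in this commit):

    // assumed constructor-local alias, as found in upstream llama-context.cpp
    const llama_hparams & hparams = model.hparams;
    // ... params_mem is filled in as shown above ...
    memory.reset(model.create_memory(params_mem, cparams, hparams));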

src/llama-model.cpp

Lines changed: 73 additions & 25 deletions
@@ -13040,46 +13040,94 @@ struct llm_build_bailingmoe : public llm_graph_context {
     }
 };

-llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
+llama_memory_i * llama_model::create_memory(
+        const llama_memory_params & params,
+        llama_cparams & cparams,
+        const llama_hparams & hparams) const {
     llama_memory_i * res;

     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max));
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                if (llm_arch_is_hybrid(arch)) {
+                    // make vectors of recurrent and non-recurrent layer indices
+                    std::vector<size_t> recurrent_layers;
+                    std::vector<size_t> unified_layers;
+                    for (auto il = 0u; il < hparams.n_layer; ++il) {
+                        if (hparams.recurrent_layer(il)) {
+                            recurrent_layers.push_back(il);
+                        } else {
+                            unified_layers.push_back(il);
+                        }
+                    }
+
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    // initialize the children
+                    std::vector<llama_kv_cache_hybrid::child_cache> children;
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_recurrent(
+                                *this,
+                                GGML_TYPE_F32,
+                                GGML_TYPE_F32,
+                                cparams.offload_kqv,
+                                std::max((uint32_t) 1, cparams.n_seq_max))
+                        ),
+                        std::move(recurrent_layers)
+                    );
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_unified(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                padding)
+                        ),
+                        std::move(unified_layers)
+                    );
+
+                    // initialize the hybrid cache with both children
+                    res = new llama_kv_cache_hybrid(hparams, std::move(children));
+                } else if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
+                        *this,
+                        GGML_TYPE_F32,
+                        GGML_TYPE_F32,
+                        cparams.offload_kqv,
+                        std::max((uint32_t) 1, cparams.n_seq_max));
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);

-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);

-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

-                res = new llama_kv_cache_unified(
-                        *this,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
-                        cparams.offload_kqv,
-                        cparams.n_ctx,
-                        padding);
+                    res = new llama_kv_cache_unified(
+                        *this,
+                        params.type_k,
+                        params.type_v,
+                        !cparams.flash_attn,
+                        cparams.offload_kqv,
+                        cparams.n_ctx,
+                        padding);
+                }
             }
     }
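The hybrid branch above hands each child cache to llama_kv_cache_hybrid together with the layer indices that cache owns. This commit does not show the child_cache definition; judging purely from the emplace_back call sites, it is presumably shaped roughly like the following (field names hypothetical):

    // Hypothetical reconstruction from the construction sites above; the real
    // definition lives in the hybrid-cache header introduced on this branch.
    struct child_cache {
        std::unique_ptr<llama_kv_cache> child;     // owned recurrent or unified cache
        std::vector<size_t>             layer_ids; // layer indices routed to this child
    };

Partitioning by layer index lets the hybrid cache route KV reads and writes for each layer to exactly one backing store: attention layers keep a standard unified KV cache, while SSM/recurrent layers keep their state-style cache.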

src/llama-model.h

Lines changed: 4 additions & 1 deletion
@@ -402,7 +402,10 @@ struct llama_model {

     // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
-    llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
+    llama_memory_i * create_memory(
+            const llama_memory_params & params,
+            llama_cparams & cparams,
+            const llama_hparams & hparams) const;

     // TODO: move this to new llm_arch_model_i interface
     llm_graph_result_ptr build_graph(
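For any out-of-tree caller of create_memory (the commit message expects none), the migration is mechanical: thread the hyper-parameters through as the new third argument. A hedged example, assuming the public model.hparams member that llama_model exposes upstream:

    // before:
    //   llama_memory_i * mem = model.create_memory(params_mem, cparams);
    // after: hparams must be passed explicitly so the memory object can
    // consult per-layer recurrence flags such as hparams.recurrent_layer(il)
    llama_memory_i * mem = model.create_memory(params_mem, cparams, model.hparams);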
