
Commit 403ba7a

feat: Instantiate hybrid cache for hybrid models (currently none)
This includes a slight architectural change where create_memory now only uses model architectures in the switch statement if their required cache type is not handled by llm_arch_is_[recurrent|hybrid].

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 251ae54 commit 403ba7a
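
For reference, the two checks the new default path relies on are defined outside this file. Below is a minimal sketch of what they are assumed to look like on the HybridCache branch: the recurrent list mirrors the cases this commit removes from the switch, while the hybrid list is empty because, per the commit title, no hybrid models exist yet. The exact signatures and contents are assumptions, not part of this diff.

// Sketch only: assumed shape of the architecture checks referenced by
// create_memory below; the real definitions live elsewhere on the branch.
bool llm_arch_is_recurrent(const llm_arch & arch) {
    switch (arch) {
        // the architectures previously listed explicitly in create_memory
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}

bool llm_arch_is_hybrid(const llm_arch & arch) {
    // currently no hybrid architectures; they would be added here as they land
    switch (arch) {
        default:
            return false;
    }
}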

File tree: 1 file changed (+92 −39 lines)

src/llama-model.cpp

Lines changed: 92 additions & 39 deletions
@@ -13195,59 +13195,112 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;

     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max));
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                if (llm_arch_is_hybrid(arch)) {
+                    // make vectors of recurrent and non-recurrent layer indices
+                    std::vector<size_t> recurrent_layers;
+                    std::vector<size_t> unified_layers;
+                    for (auto il = 0u; il < hparams.n_layer; ++il) {
+                        if (hparams.recurrent_layer(il)) {
+                            recurrent_layers.push_back(il);
+                        } else {
+                            unified_layers.push_back(il);
+                        }
+                    }

-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    // initialize the children
+                    std::vector<llama_kv_cache_hybrid::child_cache> children;
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_recurrent(
+                                *this,
+                                [&](int32_t il) {
+                                    return hparams.recurrent_layer(il);
+                                },
+                                GGML_TYPE_F32,
+                                GGML_TYPE_F32,
+                                cparams.offload_kqv,
+                                std::max((uint32_t) 1, cparams.n_seq_max))
+                        ),
+                        std::move(recurrent_layers)
+                    );
+                    children.emplace_back(
+                        std::unique_ptr<llama_kv_cache>(
+                            new llama_kv_cache_unified(
+                                *this,
+                                [&](int32_t il) {
+                                    return ! hparams.recurrent_layer(il);
+                                },
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type)
+                        ),
+                        std::move(unified_layers)
+                    );

-                if (hparams.n_swa > 0) {
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            cparams.n_ctx,
-                            params.swa_full,
-                            cparams.n_seq_max,
-                            cparams.n_batch,
-                            padding);
-                } else {
-                    res = new llama_kv_cache_unified(
+                    // initialize the hybrid cache with both children
+                    res = new llama_kv_cache_hybrid(hparams, std::move(children));
+                } else if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                         *this,
                         nullptr,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
+                        GGML_TYPE_F32,
+                        GGML_TYPE_F32,
                         cparams.offload_kqv,
-                        cparams.n_ctx,
-                        padding,
-                        hparams.n_swa,
-                        hparams.swa_type);
+                        std::max((uint32_t) 1, cparams.n_seq_max));
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.n_swa > 0) {
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                params.swa_full,
+                                cparams.n_seq_max,
+                                cparams.n_batch,
+                                padding);
+                    } else {
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
+                }
             }
     }
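
The hybrid branch above constructs a llama_kv_cache_hybrid::child_cache for each child cache together with the layer indices that child owns. That type is defined elsewhere on the branch; a plausible shape, inferred only from the emplace_back calls in this diff and labeled as an assumption, is:

// Assumption: inferred from how child_cache is constructed above (a cache
// pointer plus the layers it serves); not the branch's actual definition.
// Requires <memory>, <vector>, and the llama.cpp internal cache headers.
struct child_cache_sketch {
    std::unique_ptr<llama_kv_cache> child;  // recurrent or unified child cache
    std::vector<size_t>             layers; // layer indices routed to this child
};

Note that each child also receives a per-layer filter callback (hparams.recurrent_layer(il) or its negation), so the layer split is expressed both to the child cache itself and to the hybrid wrapper that routes operations across children.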
