Commit 8dc179e

feat: Construct hybrid recurrent cache for hybrid recurrent models
This includes a refactor of the create_memory logic so that the arch enum only needs to be referenced explicitly when a model requires cache instantiation logic beyond the standard handling for recurrent, hybrid, unified, and iSWA caches.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent b1aa5ac · commit 8dc179e
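In short, the dedicated switch cases for the recurrent architectures go away, and the default branch now classifies the architecture with predicate helpers before picking a cache implementation. The sketch below condenses that dispatch order from the diff that follows; it is an outline only, with constructor arguments, n_ctx padding, and debug logging elided (see the full diff for those details).

```cpp
// Condensed outline of the new create_memory() dispatch, read off the diff
// below; constructor arguments, n_ctx padding, and logging are elided.
llama_memory_i * res;

switch (arch) {
    // Models that need specific instantiation are handled in the switch
    case LLM_ARCH_BERT:
    case LLM_ARCH_JINA_BERT_V2:
    case LLM_ARCH_NOMIC_BERT:
        // ... remaining no-cache architectures ...
        {
            res = nullptr;
        } break;
    // Models that need standard caching rely on recurrent/hybrid checks
    default:
        {
            if (llm_arch_is_recurrent(arch)) {
                res = new llama_kv_cache_recurrent(/* ... */);
            } else if (llm_arch_is_hybrid_recurrent(arch)) {
                res = new llama_kv_cache_hybrid_recurrent(/* ... */);
            } else if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
                res = new llama_kv_cache_unified_iswa(/* ... */);
            } else {
                res = new llama_kv_cache_unified(/* ... */);
            }
        }
}
```

The practical consequence is that a new recurrent or hybrid-recurrent architecture only has to be recognized by llm_arch_is_recurrent / llm_arch_is_hybrid_recurrent; it no longer needs its own case in this switch.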

File tree: 1 file changed (+62, -47 lines)

src/llama-model.cpp

Lines changed: 62 additions & 47 deletions
@@ -13196,6 +13196,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13204,58 +13206,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_batch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model             */ *this,
+                        /* attn_type_k       */ params.type_k,
+                        /* attn_type_v       */ params.type_v,
+                        /* attn_v_trans      */ !cparams.flash_attn,
+                        /* attn_kv_size      */ cparams.n_ctx,
+                        /* attn_n_pad        */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
+                        /* recurrent_type_k  */ GGML_TYPE_F32,
+                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max         */ cparams.n_seq_max,
+                        /* offload           */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_batch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
+                }
             }
     }
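For readers without the rest of the HybridRecurrentCache branch checked out, the commented call site above is enough to read off the rough shape of the new cache's constructor. The declaration below is purely an inference from those argument comments and the expressions passed at the call site; it is not the actual header from this branch, and the real types, ordering, or defaults may differ.

```cpp
// Hypothetical constructor shape for the hybrid cache, inferred from the
// commented call site in the diff above (not the actual declaration).
llama_kv_cache_hybrid_recurrent(
        const llama_model & model,
        // attention half: roughly the knobs passed to the unified cache above
        ggml_type      attn_type_k,
        ggml_type      attn_type_v,
        bool           attn_v_trans,
        uint32_t       attn_kv_size,
        uint32_t       attn_n_pad,
        uint32_t       attn_n_swa,
        llama_swa_type attn_swa_type,
        // recurrent half: roughly the knobs passed to the recurrent cache above
        ggml_type      recurrent_type_k,
        ggml_type      recurrent_type_v,
        uint32_t       recurrent_kv_size,
        // presumably shared between both halves
        uint32_t       n_seq_max,
        bool           offload);
```

Note how the call site sizes the two halves differently: recurrent_kv_size is std::max((uint32_t) 1, cparams.n_seq_max), presumably one state slot per sequence (clamped to at least one), while attn_kv_size keeps the token-granular cparams.n_ctx.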
