
Commit 728f514

feat: Construct hybrid recurrent cache for hybrid recurrent models
This includes a refactor of the create_memory logic so that the arch enum only needs to be used explicitly when a model requires cache instantiation logic beyond the standard handling for recurrent, hybrid, unified, and iswa caches.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent ec7695f commit 728f514

File tree

1 file changed: +62 −47 lines changed

src/llama-model.cpp

Lines changed: 62 additions & 47 deletions
@@ -13190,6 +13190,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;

     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13198,58 +13200,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                    *this,
-                    nullptr,
-                    GGML_TYPE_F32,
-                    GGML_TYPE_F32,
-                    cparams.offload_kqv,
-                    std::max((uint32_t) 1, cparams.n_seq_max),
-                    cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                        *this,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
-                        cparams.offload_kqv,
-                        params.swa_full,
-                        cparams.n_ctx,
-                        cparams.n_seq_max,
-                        cparams.n_batch,
-                        padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
                         *this,
                         nullptr,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
+                        GGML_TYPE_F32,
+                        GGML_TYPE_F32,
                         cparams.offload_kqv,
-                        cparams.n_ctx,
-                        cparams.n_seq_max,
-                        padding,
-                        hparams.n_swa,
-                        hparams.swa_type);
+                        std::max((uint32_t) 1, cparams.n_seq_max),
+                        cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model             */ *this,
+                        /* attn_type_k       */ params.type_k,
+                        /* attn_type_v       */ params.type_v,
+                        /* attn_v_trans      */ !cparams.flash_attn,
+                        /* attn_kv_size      */ cparams.n_ctx,
+                        /* attn_n_pad        */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
+                        /* recurrent_type_k  */ GGML_TYPE_F32,
+                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max         */ cparams.n_seq_max,
+                        /* offload           */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                            *this,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            params.swa_full,
+                            cparams.n_ctx,
+                            cparams.n_seq_max,
+                            cparams.n_batch,
+                            padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                            *this,
+                            nullptr,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            cparams.n_ctx,
+                            cparams.n_seq_max,
+                            padding,
+                            hparams.n_swa,
+                            hparams.swa_type);
+                    }
                 }
             }
     }
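
For reference, the new dispatch relies on per-architecture predicate helpers instead of enumerating arch values at the call site. Their definitions are not part of this diff; the sketch below (assuming the llm_arch enum from llama-arch.h) is only an illustration of the shape llm_arch_is_recurrent and llm_arch_is_hybrid_recurrent might take. The recurrent list mirrors the case labels removed above; the hybrid list is a placeholder, since this commit does not show which architectures opt in.

// Sketch (not from this diff): the arch predicates the new dispatch assumes.
bool llm_arch_is_recurrent(const llm_arch & arch) {
    switch (arch) {
        // taken from the case labels deleted in create_memory above
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_RWKV7:
        case LLM_ARCH_ARWKV7:
            return true;
        default:
            return false;
    }
}

bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) {
    switch (arch) {
        // hybrid models that pair attention layers with a recurrent state
        // would be listed here; none are shown in this diff
        default:
            return false;
    }
}

With predicates like these in place, create_memory only needs explicit case labels for models whose memory setup cannot be derived from the architecture class, such as the encoder-only BERT variants above that return a null cache.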
