
Commit b1aa5ac

feat: First pass at llama_kv_cache_hybrid_recurrent
This follows the pattern in the iswa cache (llama_kv_cache_unified_iswa), where the two child caches are held explicitly. It supports the case where a model needs a single attention cache and a single recurrent cache, with each layer using exactly one of the two. This is a rewrite of the more generic approach in the original hybrid cache PR: #13276

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 6b58853 commit b1aa5ac
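
The design hinges on the two layer filters passed to the child caches being exact complements, so that every layer is served by one and only one cache. Below is a standalone sketch (not part of the commit) of that partitioning; hparams_t and its recurrent flag are hypothetical stand-ins for the real llama_hparams API.

#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// hparams_t is a hypothetical stand-in for llama_hparams; recurrent_layer()
// reports whether layer il is a recurrent (e.g. SSM) layer
struct hparams_t {
    std::vector<bool> recurrent;
    bool recurrent_layer(int32_t il) const { return recurrent[il]; }
};

int main() {
    const hparams_t hparams{{false, true, false, true}}; // layers 1 and 3 are recurrent

    // same shape as the filters passed to llama_kv_cache_unified and
    // llama_kv_cache_recurrent in the constructor in the diff below
    std::function<bool(int32_t)> filter_attn      = [&](int32_t il) { return !hparams.recurrent_layer(il); };
    std::function<bool(int32_t)> filter_recurrent = [&](int32_t il) { return  hparams.recurrent_layer(il); };

    for (int32_t il = 0; il < (int32_t) hparams.recurrent.size(); ++il) {
        // the filters are complementary: each layer lands in exactly one child cache
        assert(filter_attn(il) != filter_recurrent(il));
    }
    return 0;
}

In the constructor added by this commit, the same predicate, model.hparams.recurrent_layer(il), is negated for the attention cache and used directly for the recurrent cache.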

2 files changed (+307, −2)

src/llama-kv-cache.cpp

Lines changed: 226 additions & 2 deletions
@@ -1901,8 +1901,8 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         uint32_t n_seq_max,
         uint32_t n_batch,
         uint32_t n_pad) : hparams(model.hparams) {
-    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
@@ -3193,3 +3193,227 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     return true;
 }
+
+//
+// llama_kv_cache_hybrid_recurrent
+//
+
+class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decode_state_i {
+public:
+    llama_kv_cache_hybrid_recurrent_decode_state_t(llama_memory_status status) : status(status) {}
+
+    llama_kv_cache_hybrid_recurrent_decode_state_t(
+            llama_kv_cache_hybrid_recurrent * kv,
+            llama_sbatch sbatch,
+            std::vector<uint32_t> heads_attn,
+            std::vector<llama_ubatch> ubatches)
+        : status(LLAMA_MEMORY_STATUS_SUCCESS),
+          kv(kv),
+          sbatch(std::move(sbatch)),
+          heads_attn(std::move(heads_attn)),
+          ubatches(std::move(ubatches)) {
+    }
+
+    ~llama_kv_cache_hybrid_recurrent_decode_state_t() = default;
+
+    llama_ubatch * next() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        if (i_next >= ubatches.size()) {
+            return nullptr;
+        }
+
+        kv->get_kv_attn()     ->fill_slot(heads_attn[i_next], ubatches[i_next]);
+        kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
+
+        return &ubatches[i_next++];
+    }
+
+    std::vector<int64_t> & out_ids() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        return sbatch.out_ids;
+    }
+
+    llama_memory_status get_status() const override {
+        return status;
+    }
+
+private:
+    const llama_memory_status status;
+
+    llama_kv_cache_hybrid_recurrent * kv;
+
+    llama_sbatch sbatch;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<uint32_t> heads_attn;
+    std::vector<llama_ubatch> ubatches;
+};
+
+llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
+        const llama_model & model,
+        /* attn */
+        ggml_type      attn_type_k,
+        ggml_type      attn_type_v,
+        bool           attn_v_trans,
+        uint32_t       attn_kv_size,
+        uint32_t       attn_n_pad,
+        uint32_t       attn_n_swa,
+        llama_swa_type attn_swa_type,
+        /* recurrent */
+        ggml_type      recurrent_type_k,
+        ggml_type      recurrent_type_v,
+        uint32_t       recurrent_kv_size,
+        /* common */
+        uint32_t       n_seq_max,
+        bool           offload) :
+    hparams(model.hparams),
+    kv_attn(new llama_kv_cache_unified(
+        model,
+        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_type_k,
+        attn_type_v,
+        attn_v_trans,
+        offload,
+        attn_kv_size,
+        n_seq_max,
+        attn_n_pad,
+        attn_n_swa,
+        attn_swa_type
+    )),
+    kv_recurrent(new llama_kv_cache_recurrent(
+        model,
+        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_type_k,
+        recurrent_type_v,
+        offload,
+        recurrent_kv_size,
+        n_seq_max
+    )) {}
+
+void llama_kv_cache_hybrid_recurrent::clear() {
+    kv_attn     ->clear();
+    kv_recurrent->clear();
+}
+
+bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!kv_recurrent->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return kv_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_attn     ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) {
+    kv_attn     ->seq_keep(seq_id);
+    kv_recurrent->seq_keep(seq_id);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_attn     ->seq_add(seq_id, p0, p1, shift);
+    kv_recurrent->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_attn     ->seq_div(seq_id, p0, p1, d);
+    kv_recurrent->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id));
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
+}
+
+llama_memory_decode_state_ptr llama_kv_cache_hybrid_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!kv_recurrent->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = kv_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
+
+bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
+    bool res = false;
+
+    res = res | kv_attn     ->update(lctx);
+    res = res | kv_recurrent->update(lctx);
+
+    return res;
+}
+
+void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
+    kv_attn     ->defrag_sched(thold);
+    kv_recurrent->defrag_sched(thold);
+}
+
+void llama_kv_cache_hybrid_recurrent::set_full() {
+    kv_attn     ->set_full();
+    kv_recurrent->set_full();
+}
+
+bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
+    // TODO: Should this return true if the attention cache can shift?
+    return false;
+}
+
+void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    kv_attn     ->state_write(io, seq_id);
+    kv_recurrent->state_write(io, seq_id);
+}
+
+void llama_kv_cache_hybrid_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    kv_attn     ->state_read(io, seq_id);
+    kv_recurrent->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_kv_cache_hybrid_recurrent::get_kv_attn() const {
+    return kv_attn.get();
+}
+
+llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const {
+    return kv_recurrent.get();
+}
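
The decode state added above is meant to be drained by the caller: each call to next() positions both child caches for the current ubatch (fill_slot on the attention cache, find_slot on the recurrent cache) and returns it, and nullptr signals that the batch is exhausted. A standalone sketch of that consumption loop, with decode_state_t and ubatch_t as simplified stand-ins for the real llama.cpp types:

#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// ubatch_t and decode_state_t are simplified stand-ins for llama_ubatch and
// llama_kv_cache_hybrid_recurrent_decode_state_t
struct ubatch_t { int n_tokens; };

class decode_state_t {
public:
    explicit decode_state_t(std::vector<ubatch_t> ubatches) : ubatches(std::move(ubatches)) {}

    // mirrors llama_memory_decode_state_i::next(): return the next ubatch, or
    // nullptr once all ubatches have been handed out (the real implementation
    // also positions both child caches here via fill_slot/find_slot)
    ubatch_t * next() {
        if (i_next >= ubatches.size()) {
            return nullptr;
        }
        return &ubatches[i_next++];
    }

private:
    size_t i_next = 0;
    std::vector<ubatch_t> ubatches;
};

int main() {
    // three ubatches produced by a hypothetical split step
    decode_state_t state({{32}, {32}, {7}});

    // drain the state: one decode step per ubatch
    while (ubatch_t * ub = state.next()) {
        std::printf("decode ubatch of %d tokens\n", ub->n_tokens);
    }
    return 0;
}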

src/llama-kv-cache.h

Lines changed: 81 additions & 0 deletions
@@ -495,3 +495,84 @@ class llama_kv_cache_recurrent_state_i : public llama_memory_state_i {
     virtual int32_t s_copy(int i) const = 0;
     virtual float   s_mask(int i) const = 0;
 };
+
+//
+// llama_kv_cache_hybrid_recurrent
+//
+
+// utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
+// support models where each layer may be either attention-based or recurrent
+
+class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {
+public:
+    llama_kv_cache_hybrid_recurrent(
+            const llama_model & model,
+            /* attn */
+            ggml_type      attn_type_k,
+            ggml_type      attn_type_v,
+            bool           attn_v_trans,
+            uint32_t       attn_kv_size,
+            uint32_t       attn_n_pad,
+            uint32_t       attn_n_swa,
+            llama_swa_type attn_swa_type,
+            /* recurrent */
+            ggml_type      recurrent_type_k,
+            ggml_type      recurrent_type_v,
+            uint32_t       recurrent_kv_size,
+            /* common */
+            uint32_t       n_seq_max,
+            bool           offload);
+
+    ~llama_kv_cache_hybrid_recurrent() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    void clear() override;
+
+    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    //
+    // llama_kv_cache
+    //
+
+    llama_memory_decode_state_ptr init(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
+
+    bool update(llama_context & lctx) override;
+
+    void defrag_sched(float thold) override;
+
+    void set_full() override;
+
+    bool get_can_shift() const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
+
+    //
+    // llama_kv_cache_hybrid_recurrent specific API
+    //
+
+    llama_kv_cache_unified   * get_kv_attn     () const;
+    llama_kv_cache_recurrent * get_kv_recurrent() const;
+
+private:
+    const llama_hparams & hparams;
+
+    const std::unique_ptr<llama_kv_cache_unified>   kv_attn;
+    const std::unique_ptr<llama_kv_cache_recurrent> kv_recurrent;
+};
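
One detail worth noting in the seq_pos_min/seq_pos_max overrides declared here and implemented above: the hybrid cache reports only the intersection of the position ranges held by its two children, hence the max of the mins and the min of the maxes. A small standalone sketch with made-up ranges:

#include <algorithm>
#include <cstdio>

int main() {
    // hypothetical per-sequence position ranges held by each child cache
    const int attn_min = 0,  attn_max = 127;  // attention cache covers [0, 127]
    const int recr_min = 64, recr_max = 200;  // recurrent cache covers [64, 200]

    // the hybrid cache can only vouch for positions present in *both* children
    const int hybrid_min = std::max(attn_min, recr_min); // 64
    const int hybrid_max = std::min(attn_max, recr_max); // 127

    std::printf("hybrid range: [%d, %d]\n", hybrid_min, hybrid_max);
    return 0;
}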
