Commit 33c72e5 (1 parent: dee0814)

fix: Fix recurrent cache impl for llama_memory_state_i paradigm after rebase

Branch: HybridRecurrentCache
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2 files changed: +71, -23 lines
src/llama-kv-cache.cpp

Lines changed: 60 additions & 20 deletions

@@ -3198,11 +3198,21 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 // llama_kv_cache_hybrid_recurrent
 //
 
-class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decode_state_i {
+class llama_kv_cache_hybrid_recurrent_state : public llama_kv_cache_hybrid_recurrent_state_i {
 public:
-    llama_kv_cache_hybrid_recurrent_decode_state_t(llama_memory_status status) : status(status) {}
-
-    llama_kv_cache_hybrid_recurrent_decode_state_t(
+    // init failure
+    explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
+        : status(status), state_attn(status), state_recurrent(status) {}
+
+    // init full
+    explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
+        : status(LLAMA_MEMORY_STATUS_SUCCESS),
+          kv(kv),
+          state_attn(status, kv->get_kv_attn()),
+          state_recurrent(status, kv->get_kv_recurrent()) {}
+
+    // init success
+    llama_kv_cache_hybrid_recurrent_state(
            llama_kv_cache_hybrid_recurrent * kv,
            llama_sbatch sbatch,
            std::vector<uint32_t> heads_attn,
@@ -3211,22 +3221,33 @@ class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decod
        kv(kv),
        sbatch(std::move(sbatch)),
        heads_attn(std::move(heads_attn)),
-       ubatches(std::move(ubatches)) {
+       ubatches(std::move(ubatches)),
+       // NOTE: these child states are only used as wrapper APIs for the
+       //  const methods, so we use the "init full" signature since the
+       //  actual state is not used.
+       state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()),
+       state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {
     }
 
-    ~llama_kv_cache_hybrid_recurrent_decode_state_t() = default;
+    ~llama_kv_cache_hybrid_recurrent_state() = default;
 
-    llama_ubatch * next() override {
+    bool next() override {
        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-       if (i_next >= ubatches.size()) {
-           return nullptr;
+       if (++i_next >= ubatches.size()) {
+           return false;
        }
 
-       kv->get_kv_attn()     ->fill_slot(heads_attn[i_next], ubatches[i_next]);
+       return true;
+    }
+
+    bool apply() override {
+       assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+       kv->get_kv_attn()     ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
        kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
 
-       return &ubatches[i_next++];
+       return true;
     }
 
     std::vector<int64_t> & out_ids() override {
@@ -3239,6 +3260,23 @@ class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decod
        return status;
     }
 
+    const llama_ubatch & get_ubatch() const override {
+       assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+       return ubatches[i_next];
+    }
+
+    //
+    // llama_kv_cache_hybrid_recurrent_state_i
+    //
+
+    const llama_kv_cache_unified_state_i * get_state_attn () const override {
+       return &state_attn;
+    }
+
+    const llama_kv_cache_recurrent_state_i * get_state_recurrent() const override {
+       return &state_recurrent;
+    }
+
 private:
     const llama_memory_status status;
 
@@ -3251,6 +3289,9 @@ class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decod
 
     std::vector<uint32_t> heads_attn;
     std::vector<llama_ubatch> ubatches;
+
+    const llama_kv_cache_unified_state     state_attn;
+    const llama_kv_cache_recurrent_state_t state_recurrent;
 };
 
 llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
@@ -3338,7 +3379,7 @@ llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) cons
     return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
 }
 
-llama_memory_decode_state_ptr llama_kv_cache_hybrid_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
 
     // since this includes a recurrent cache, we cannot use split_simple
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
@@ -3362,20 +3403,24 @@ llama_memory_decode_state_ptr llama_kv_cache_hybrid_recurrent::init(const llama_
     if (!kv_recurrent->prepare(ubatches)) {
        // TODO: will the recurrent cache be in an undefined state at this point?
        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
-       return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+       return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
     // prepare the attention cache
     auto heads_attn = kv_attn->prepare(ubatches);
     if (heads_attn.empty()) {
        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
-       return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+       return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
 }
 
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
+}
+
 bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
     bool res = false;
 
@@ -3390,11 +3435,6 @@ void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
     kv_recurrent->defrag_sched(thold);
 }
 
-void llama_kv_cache_hybrid_recurrent::set_full() {
-    kv_attn     ->set_full();
-    kv_recurrent->set_full();
-}
-
 bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
     // TODO: Should this return true if the attention cache can shift?
     return false;
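
The next()/apply() split shown above replaces the old next()-returns-pointer API: the state object is built once per batch, apply() materializes the current ubatch in the underlying caches, and next() pre-increments the index and reports whether any ubatches remain. For orientation, here is a minimal, self-contained sketch of that paradigm; all types and names below are simplified stand-ins for illustration, not the actual llama.cpp declarations.

#include <cassert>
#include <cstdio>
#include <vector>

// Simplified stand-in types, for illustration only (not llama.cpp declarations).
enum memory_status { STATUS_SUCCESS, STATUS_FAILED_PREPARE };
struct ubatch { int n_tokens; };

// Minimal analogue of the memory-state interface this commit targets.
struct memory_state_i {
    virtual ~memory_state_i() = default;
    virtual bool next()  = 0;                       // advance to the next ubatch
    virtual bool apply() = 0;                       // apply the current ubatch to the cache(s)
    virtual const ubatch & get_ubatch() const = 0;  // the ubatch apply() just placed
    virtual memory_status  get_status() const = 0;
};

// Mirrors the hybrid state's semantics: next() pre-increments and reports
// whether ubatches remain; apply() is where both child caches would be touched.
struct hybrid_state : memory_state_i {
    explicit hybrid_state(std::vector<ubatch> ub) : ubatches(std::move(ub)) {}

    bool next() override {
        assert(status == STATUS_SUCCESS);
        return ++i_next < ubatches.size();
    }

    bool apply() override {
        assert(status == STATUS_SUCCESS);
        // the real code calls kv_attn->apply_ubatch(...) and
        // kv_recurrent->find_slot(...) for ubatches[i_next] here
        return true;
    }

    const ubatch & get_ubatch() const override { return ubatches[i_next]; }
    memory_status  get_status() const override { return status; }

private:
    memory_status       status = STATUS_SUCCESS;
    size_t              i_next = 0;
    std::vector<ubatch> ubatches;
};

int main() {
    hybrid_state state({{32}, {32}, {7}});
    if (state.get_status() != STATUS_SUCCESS) {
        return 1;
    }
    // decode-style driver: apply the current ubatch, consume it, then advance
    do {
        if (!state.apply()) {
            break;
        }
        std::printf("ubatch with %d tokens\n", state.get_ubatch().n_tokens);
    } while (state.next());
    return 0;
}

Note the pre-increment in next(): the first ubatch is available immediately after construction, so a driver applies and consumes before advancing.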

src/llama-kv-cache.h

Lines changed: 11 additions & 3 deletions

@@ -544,18 +544,18 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {
     // llama_kv_cache
     //
 
-    llama_memory_decode_state_ptr init(
+    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) override;
 
+    llama_memory_state_ptr init_full() override;
+
     bool update(llama_context & lctx) override;
 
     void defrag_sched(float thold) override;
 
-    void set_full() override;
-
     bool get_can_shift() const override;
 
     // state write/load
@@ -576,3 +576,11 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {
     const std::unique_ptr<llama_kv_cache_unified>   kv_attn;
     const std::unique_ptr<llama_kv_cache_recurrent> kv_recurrent;
 };
+
+class llama_kv_cache_hybrid_recurrent_state_i : public llama_memory_state_i {
+public:
+    virtual ~llama_kv_cache_hybrid_recurrent_state_i() = default;
+
+    virtual const llama_kv_cache_unified_state_i *   get_state_attn     () const = 0;
+    virtual const llama_kv_cache_recurrent_state_i * get_state_recurrent() const = 0;
+};
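
The llama_kv_cache_hybrid_recurrent_state_i interface declared above lets a consumer obtain a per-cache view from a single memory state; the wrapper child states added in the .cpp change are what back these accessors. Below is a minimal, self-contained sketch of that shape; the stand-in types and the per-layer dispatch in main() are assumptions for illustration, not code from this commit.

#include <cstdio>
#include <vector>

// Simplified stand-ins for the per-cache state views; illustration only.
struct attn_state_i      { int n_kv   = 128; };
struct recurrent_state_i { int r_slot = 3;   };

// Analogue of llama_kv_cache_hybrid_recurrent_state_i: a single memory state
// that exposes a read-only view for each child cache.
struct hybrid_state_i {
    virtual ~hybrid_state_i() = default;
    virtual const attn_state_i      * get_state_attn     () const = 0;
    virtual const recurrent_state_i * get_state_recurrent() const = 0;
};

struct hybrid_state : hybrid_state_i {
    const attn_state_i      * get_state_attn     () const override { return &attn;      }
    const recurrent_state_i * get_state_recurrent() const override { return &recurrent; }

private:
    attn_state_i      attn;
    recurrent_state_i recurrent;
};

int main() {
    hybrid_state st;

    // A hybrid model's graph builder would pick the view per layer type:
    // attention layers read the unified-KV view, recurrent (SSM) layers read
    // the recurrent-state view.
    std::vector<bool> layer_is_recurrent = {false, true, false, true};
    for (size_t il = 0; il < layer_is_recurrent.size(); ++il) {
        if (layer_is_recurrent[il]) {
            std::printf("layer %zu: recurrent slot %d\n", il, st.get_state_recurrent()->r_slot);
        } else {
            std::printf("layer %zu: attention n_kv %d\n", il, st.get_state_attn()->n_kv);
        }
    }
    return 0;
}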
