@@ -3198,11 +3198,21 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 // llama_kv_cache_hybrid_recurrent
 //
 
-class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decode_state_i {
+class llama_kv_cache_hybrid_recurrent_state : public llama_kv_cache_hybrid_recurrent_state_i {
 public:
-    llama_kv_cache_hybrid_recurrent_decode_state_t(llama_memory_status status) : status(status) {}
-
-    llama_kv_cache_hybrid_recurrent_decode_state_t(
+    // init failure
+    explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
+        : status(status), state_attn(status), state_recurrent(status) {}
+
+    // init full
+    explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
+        : status(LLAMA_MEMORY_STATUS_SUCCESS),
+          kv(kv),
+          state_attn(status, kv->get_kv_attn     ()),
+          state_recurrent(status, kv->get_kv_recurrent()) {}
+
+    // init success
+    llama_kv_cache_hybrid_recurrent_state(
         llama_kv_cache_hybrid_recurrent * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_attn,
@@ -3211,22 +3221,33 @@ class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decod
         kv(kv),
         sbatch(std::move(sbatch)),
         heads_attn(std::move(heads_attn)),
-        ubatches(std::move(ubatches)) {
+        ubatches(std::move(ubatches)),
+        // NOTE: these child states are only used as wrapper APIs for the
+        //  const methods, so we use the "init full" signature since the
+        //  actual state is not used.
+        state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn     ()),
+        state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {
     }
 
-    ~llama_kv_cache_hybrid_recurrent_decode_state_t() = default;
+    ~llama_kv_cache_hybrid_recurrent_state() = default;
 
-    llama_ubatch * next() override {
+    bool next() override {
         assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-        if (i_next >= ubatches.size()) {
-            return nullptr;
+        if (++i_next >= ubatches.size()) {
+            return false;
         }
 
-        kv->get_kv_attn()     ->fill_slot(heads_attn[i_next], ubatches[i_next]);
+        return true;
+    }
+
+    bool apply() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        kv->get_kv_attn()     ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
         kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
 
-        return &ubatches[i_next++];
+        return true;
     }
 
     std::vector<int64_t> & out_ids() override {
@@ -3239,6 +3260,23 @@ class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decod
         return status;
     }
 
+    const llama_ubatch & get_ubatch() const override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+        return ubatches[i_next];
+    }
+
+    //
+    // llama_kv_cache_hybrid_recurrent_state_i
+    //
+
+    const llama_kv_cache_unified_state_i * get_state_attn() const override {
+        return &state_attn;
+    }
+
+    const llama_kv_cache_recurrent_state_i * get_state_recurrent() const override {
+        return &state_recurrent;
+    }
+
 private:
     const llama_memory_status status;
 
@@ -3251,6 +3289,9 @@ class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decod
 
     std::vector<uint32_t>     heads_attn;
    std::vector<llama_ubatch> ubatches;
+
+    const llama_kv_cache_unified_state     state_attn;
+    const llama_kv_cache_recurrent_state_t state_recurrent;
 };
 
 llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
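
The next()/apply() split above changes the iteration contract: next() now only advances the cursor and reports whether another ubatch remains, while apply() writes the current ubatch into both child caches. As a minimal caller-side sketch (assumed usage, not part of this diff; the real driver lives elsewhere in llama.cpp), the state is presumably consumed roughly like this, assuming it starts positioned on the first ubatch:

// Hypothetical helper; `kv`, `batch`, `n_ubatch`, `embd_pooled` and
// `logits_all` are assumed to be provided by the caller.
static void decode_with_hybrid_cache(llama_kv_cache_hybrid_recurrent * kv,
        const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
    auto state = kv->init_batch(batch, n_ubatch, embd_pooled, logits_all);
    if (state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return; // failed to prepare the ubatches
    }

    do {
        const llama_ubatch & ubatch = state->get_ubatch(); // current micro-batch
        state->apply(); // reserve slots for `ubatch` in both child caches
        // ... build and compute the graph for `ubatch`, using get_state_attn()
        //     and get_state_recurrent() for the per-cache views ...
        (void) ubatch; // unused in this sketch
    } while (state->next()); // advance; returns false after the last ubatch
}
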
@@ -3338,7 +3379,7 @@ llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) cons
     return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
 }
 
-llama_memory_decode_state_ptr llama_kv_cache_hybrid_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
 
     // since this includes a recurrent cache, we cannot use split_simple
     auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
@@ -3362,20 +3403,24 @@ llama_memory_decode_state_ptr llama_kv_cache_hybrid_recurrent::init(const llama_
     if (!kv_recurrent->prepare(ubatches)) {
         // TODO: will the recurrent cache be in an undefined state at this point?
         LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
-        return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
     // prepare the attention cache
     auto heads_attn = kv_attn->prepare(ubatches);
     if (heads_attn.empty()) {
         LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
-        return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
         this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
 }
 
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
+}
+
 bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
     bool res = false;
 
@@ -3390,11 +3435,6 @@ void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
     kv_recurrent->defrag_sched(thold);
 }
 
-void llama_kv_cache_hybrid_recurrent::set_full() {
-    kv_attn     ->set_full();
-    kv_recurrent->set_full();
-}
-
 bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
     // TODO: Should this return true if the attention cache can shift?
     return false;
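
With set_full() removed, the full-cache path presumably goes through init_full() instead: rather than mutating both child caches in place, the hybrid cache hands back a state built with the "init full" constructor. A rough sketch of the assumed replacement on the caller side (hypothetical, not part of this diff):

// Before (old API, removed above):
//   kv->set_full();
//   ... build the worst-case graph directly against the caches ...
//
// After (assumed usage of the new API):
static void reserve_full_graph(llama_kv_cache_hybrid_recurrent * kv) {
    auto full = kv->init_full();
    if (!full || full->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return;
    }
    // ... build the worst-case graph through full->get_state_attn() and
    //     full->get_state_recurrent() instead of touching kv directly ...
}
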