@@ -1901,8 +1901,8 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         uint32_t n_seq_max,
         uint32_t n_batch,
         uint32_t n_pad) : hparams(model.hparams) {
-    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
@@ -3193,3 +3193,227 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     return true;
 }
+
+//
+// llama_kv_cache_hybrid_recurrent
+//
+
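+// decode state for the hybrid cache: it owns the sbatch, the pre-computed
+// attention cache heads, and the ubatch splits produced by init(), and it
+// applies them to both child caches as each ubatch is handed out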
+class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decode_state_i {
+public:
+    llama_kv_cache_hybrid_recurrent_decode_state_t(llama_memory_status status) : status(status) {}
+
+    llama_kv_cache_hybrid_recurrent_decode_state_t(
+            llama_kv_cache_hybrid_recurrent * kv,
+            llama_sbatch sbatch,
+            std::vector<uint32_t> heads_attn,
+            std::vector<llama_ubatch> ubatches)
+        : status(LLAMA_MEMORY_STATUS_SUCCESS),
+          kv(kv),
+          sbatch(std::move(sbatch)),
+          heads_attn(std::move(heads_attn)),
+          ubatches(std::move(ubatches)) {
+    }
+
+    ~llama_kv_cache_hybrid_recurrent_decode_state_t() = default;
+
+    llama_ubatch * next() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        if (i_next >= ubatches.size()) {
+            return nullptr;
+        }
+
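+        // place the ubatch into both child caches before handing it out, so
+        // the attention and recurrent states advance in lockstep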
+        kv->get_kv_attn()     ->fill_slot(heads_attn[i_next], ubatches[i_next]);
+        kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
+
+        return &ubatches[i_next++];
+    }
+
+    std::vector<int64_t> & out_ids() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        return sbatch.out_ids;
+    }
+
+    llama_memory_status get_status() const override {
+        return status;
+    }
+
+private:
+    const llama_memory_status status;
+
+    llama_kv_cache_hybrid_recurrent * kv;
+
+    llama_sbatch sbatch;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<uint32_t> heads_attn;
+    std::vector<llama_ubatch> ubatches;
+};
+
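+// the hybrid cache owns one unified attention cache and one recurrent cache;
+// each layer is routed to exactly one of them via hparams.recurrent_layer(il)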
+llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
+        const llama_model & model,
+        /* attn */
+        ggml_type attn_type_k,
+        ggml_type attn_type_v,
+        bool attn_v_trans,
+        uint32_t attn_kv_size,
+        uint32_t attn_n_pad,
+        uint32_t attn_n_swa,
+        llama_swa_type attn_swa_type,
+        /* recurrent */
+        ggml_type recurrent_type_k,
+        ggml_type recurrent_type_v,
+        uint32_t recurrent_kv_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload) :
+    hparams(model.hparams),
+    kv_attn(new llama_kv_cache_unified(
+        model,
+        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_type_k,
+        attn_type_v,
+        attn_v_trans,
+        offload,
+        attn_kv_size,
+        n_seq_max,
+        attn_n_pad,
+        attn_n_swa,
+        attn_swa_type
+    )),
+    kv_recurrent(new llama_kv_cache_recurrent(
+        model,
+        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_type_k,
+        recurrent_type_v,
+        offload,
+        recurrent_kv_size,
+        n_seq_max
+    )) {}
+void llama_kv_cache_hybrid_recurrent::clear() {
+    kv_attn     ->clear();
+    kv_recurrent->clear();
+}
+
+bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!kv_recurrent->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return kv_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_attn     ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) {
+    kv_attn     ->seq_keep(seq_id);
+    kv_recurrent->seq_keep(seq_id);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_attn     ->seq_add(seq_id, p0, p1, shift);
+    kv_recurrent->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_attn     ->seq_div(seq_id, p0, p1, d);
+    kv_recurrent->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id));
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
+}
+
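+// illustrative call pattern (not part of this change): the caller iterates
+// the returned decode state, e.g.
+//
+//   auto state = kv->init(batch, n_ubatch, embd_pooled, logits_all);
+//   while (llama_ubatch * ubatch = state->next()) {
+//       // ... build and run the graph for this ubatch ...
+//   }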
+llama_memory_decode_state_ptr llama_kv_cache_hybrid_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!kv_recurrent->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = kv_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
+
+bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
+    bool res = false;
+
+    res = res | kv_attn     ->update(lctx);
+    res = res | kv_recurrent->update(lctx);
+
+    return res;
+}
+
+void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
+    kv_attn     ->defrag_sched(thold);
+    kv_recurrent->defrag_sched(thold);
+}
+
+void llama_kv_cache_hybrid_recurrent::set_full() {
+    kv_attn     ->set_full();
+    kv_recurrent->set_full();
+}
+
+bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
+    // TODO: Should this return true if the attention cache can shift?
+    return false;
+}
+
+void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    kv_attn     ->state_write(io, seq_id);
+    kv_recurrent->state_write(io, seq_id);
+}
+
+void llama_kv_cache_hybrid_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    kv_attn     ->state_read(io, seq_id);
+    kv_recurrent->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_kv_cache_hybrid_recurrent::get_kv_attn() const {
+    return kv_attn.get();
+}
+
+llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const {
+    return kv_recurrent.get();
+}