Skip to content

Commit a4cc4aa

Browse files
committed
fix: Overhaul hybrid cache for refactor part3 (::init interface)
Branch: HybridCache Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 8aee2e7 commit a4cc4aa

File tree

3 files changed

+137
-78
lines changed

3 files changed

+137
-78
lines changed

src/llama-kv-cache.cpp

Lines changed: 126 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
399399
break;
400400
}
401401

402-
// remeber the position that we found
402+
// remember the position that we found
403403
res.push_back(head_new);
404404

405405
// store the old state of the cells in the recovery stack
@@ -3037,11 +3037,93 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
30373037
//
30383038
// llama_kv_cache_hybrid
30393039
//
3040-
llama_kv_cache_hybrid::llama_kv_cache_hybrid(
3041-
const llama_hparams & hparams,
3042-
std::vector<child_cache> children) :
3043-
m_hparams(hparams),
3044-
m_children(
3040+
3041+
3042+
class llama_kv_cache_hybrid_decode_state_t : public llama_memory_decode_state_i {
3043+
public:
3044+
explicit llama_kv_cache_hybrid_decode_state_t(
3045+
std::vector<llama_memory_decode_state_ptr> decode_states) :
3046+
status([](const std::vector<llama_memory_decode_state_ptr> & decode_states) -> llama_memory_status {
3047+
for (const auto & decode_state : decode_states) {
3048+
if (!decode_state) {
3049+
return LLAMA_MEMORY_STATUS_FAILED_PREPARE;
3050+
}
3051+
const auto & status = decode_state->get_status();
3052+
if (status != LLAMA_MEMORY_STATUS_SUCCESS) {
3053+
return status;
3054+
}
3055+
}
3056+
return LLAMA_MEMORY_STATUS_SUCCESS;
3057+
}(decode_states)),
3058+
decode_states(std::move(decode_states)) {
3059+
3060+
// make sure at least one decode state
3061+
assert(!decode_states.empty());
3062+
3063+
// make sure all out_ids match across states
3064+
// TODO: This could be expensive, so maybe don't do it?
3065+
const auto & out_ids = decode_states[0]->out_ids();
3066+
for (size_t i = 1; i < decode_states.size(); ++i) {
3067+
const auto & out_ids_i = decode_states[i]->out_ids();
3068+
assert(out_ids.size() == out_ids_i.size());
3069+
for (size_t j = 0; j < out_ids.size(); ++j) {
3070+
assert(out_ids[j] == out_ids_i[j]);
3071+
}
3072+
}
3073+
}
3074+
3075+
~llama_kv_cache_hybrid_decode_state_t() = default;
3076+
3077+
llama_ubatch * next() override {
3078+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
3079+
3080+
// hit next on each child
3081+
std::vector<llama_ubatch *> next_ubatches;
3082+
for (const auto & decode_state : decode_states) {
3083+
next_ubatches.push_back(decode_state->next());
3084+
}
3085+
3086+
// make sure they all match
3087+
// TODO: unnecessary safety?
3088+
llama_ubatch * res = next_ubatches[0];
3089+
assert(res);
3090+
for (size_t i = 1; i < next_ubatches.size(); ++i) {
3091+
llama_ubatch * ubatch_i = next_ubatches[i];
3092+
assert(ubatch_i);
3093+
assert(ubatch_i->n_tokens == res->n_tokens);
3094+
assert(ubatch_i->n_seq_tokens == res->n_seq_tokens);
3095+
assert(ubatch_i->n_seqs == res->n_seqs);
3096+
for (size_t j = 0; j < res->n_tokens; ++j) {
3097+
assert(ubatch_i->token[j] == res->token[j]);
3098+
assert(ubatch_i->pos[j] == res->pos[j]);
3099+
assert(ubatch_i->output[j] == res->output[j]);
3100+
}
3101+
for (size_t j = 0; j < res->n_seqs; ++j) {
3102+
assert(ubatch_i->n_seq_id[j] == res->n_seq_id[j]);
3103+
}
3104+
}
3105+
3106+
// return the first ubatch since they all match
3107+
return res;
3108+
}
3109+
3110+
std::vector<int64_t> & out_ids() override {
3111+
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
3112+
3113+
return decode_states[0]->out_ids();
3114+
}
3115+
3116+
llama_memory_status get_status() const override {
3117+
return status;
3118+
}
3119+
3120+
private:
3121+
const llama_memory_status status;
3122+
std::vector<llama_memory_decode_state_ptr> decode_states;
3123+
};
3124+
3125+
llama_kv_cache_hybrid::llama_kv_cache_hybrid(std::vector<child_cache> children_) :
3126+
children(
30453127
[](std::vector<child_cache>& caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
30463128
// Sort the caches by the lowest layer ID so the order is repeatable
30473129
for (auto & cache : caches) {
@@ -3056,26 +3138,26 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid(
30563138
unique_caches.emplace(cache.child.release());
30573139
}
30583140
return unique_caches;
3059-
}(children)
3141+
}(children_)
30603142
),
3061-
m_has_recurrent(
3143+
has_recurrent(
30623144
[](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
30633145
for (const auto & cache : caches) {
30643146
if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
30653147
return true;
30663148
}
30673149
}
30683150
return false;
3069-
}(m_children)
3151+
}(children)
30703152
)
30713153
{
30723154
// Ensure at least one child
3073-
GGML_ASSERT(m_children.size() > 0);
3155+
GGML_ASSERT(children.size() > 0);
30743156

30753157
// Ensure layers are not overlapping and are concurrent
30763158
std::set<size_t> seen_layers;
30773159
size_t max_layer = 0;
3078-
for (const auto & cache : children) {
3160+
for (const auto & cache : children_) {
30793161
for (const auto & layer_id : cache.layer_ids) {
30803162
GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
30813163
seen_layers.insert(layer_id);
@@ -3089,7 +3171,7 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid(
30893171
}
30903172

30913173
void llama_kv_cache_hybrid::clear() {
3092-
for (const auto & cache : m_children) {
3174+
for (const auto & cache : children) {
30933175
cache->clear();
30943176
}
30953177
}
@@ -3102,40 +3184,40 @@ bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
31023184
}
31033185

31043186
// Do the removal from each child which should never fail
3105-
for (const auto & cache : m_children) {
3187+
for (const auto & cache : children) {
31063188
const bool failed = cache->seq_rm(seq_id, p0, p1);
31073189
GGML_ASSERT(!failed);
31083190
}
31093191
return true;
31103192
}
31113193

31123194
// Propagate a sequence copy over [p0, p1) to every child cache.
void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    for (const auto & child : children) {
        child->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    }
}
31173199

31183200
// Keep only the given sequence in every child cache.
void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
    for (const auto & child : children) {
        child->seq_keep(seq_id);
    }
}
31233205

31243206
// Shift positions of seq_id in [p0, p1) by delta in every child cache.
void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
    for (const auto & child : children) {
        child->seq_add(seq_id, p0, p1, delta);
    }
}
31293211

31303212
// Divide positions of seq_id in [p0, p1) by d in every child cache.
void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    for (const auto & child : children) {
        child->seq_div(seq_id, p0, p1, d);
    }
}
31353217

31363218
llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const {
31373219
llama_pos min_pos = -1;
3138-
for (const auto & cache : m_children) {
3220+
for (const auto & cache : children) {
31393221
const auto child_min_pos = cache->seq_pos_min(seq_id);
31403222
min_pos = min_pos == -1 ? child_min_pos : std::min(min_pos, child_min_pos);
31413223
}
@@ -3144,81 +3226,67 @@ llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const {
31443226

31453227
// Maximum position of seq_id across all child caches.
// NOTE(review): returns 0 when no child holds the sequence, while seq_pos_min
// uses -1 as its "not found" sentinel — confirm this asymmetry is intended.
llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos max_pos = 0;
    for (const auto & child : children) {
        max_pos = std::max(max_pos, child->seq_pos_max(seq_id));
    }
    return max_pos;
}
31523234

3153-
void llama_kv_cache_hybrid::restore() {
3154-
for (const auto & cache : m_children) {
3155-
cache->restore();
3156-
}
3157-
}
3235+
// Prepare a decode over `batch`: delegate to every child cache and wrap the
// resulting per-child decode states in a single hybrid decode state.
llama_memory_decode_state_ptr llama_kv_cache_hybrid::init(
        const llama_batch & batch,
                 uint32_t   n_ubatch,
                     bool   embd_pooled,
                     bool   logits_all,
                     bool   split_equal) {

    // recurrent children require equal splits
    // TODO: just ignore this if set incorrectly?
    assert(!has_recurrent || split_equal);

    // collect the decode state produced by each child
    std::vector<llama_memory_decode_state_ptr> decode_states;
    decode_states.reserve(children.size());
    for (const auto & child : children) {
        decode_states.emplace_back(child->init(batch, n_ubatch, embd_pooled, logits_all, split_equal));
    }

    // the hybrid state aggregates the children's statuses and ubatch iteration
    return std::make_unique<llama_kv_cache_hybrid_decode_state_t>(std::move(decode_states));
}
31643256

31653257
// Run update on every child; true if any child reported an update.
bool llama_kv_cache_hybrid::update(llama_context & ctx) {
    bool updated = false;
    for (const auto & child : children) {
        // call update() first so short-circuiting never skips a child
        updated = child->update(ctx) || updated;
    }
    return updated;
}
31723264

31733265
void llama_kv_cache_hybrid::defrag_sched(float thold) {
3174-
for (const auto & cache : m_children) {
3266+
for (const auto & cache : children) {
31753267
cache->defrag_sched(thold);
31763268
}
31773269
}
31783270

31793271
void llama_kv_cache_hybrid::set_full() {
3180-
for (const auto & cache : m_children) {
3272+
for (const auto & cache : children) {
31813273
cache->set_full();
31823274
}
31833275
}
31843276

31853277
// Removal of [p0, p1) for seq_id is only possible if EVERY child allows it.
bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
    for (const auto & child : children) {
        if (!child->can_seq_rm(seq_id, p0, p1)) {
            return false;
        }
    }
    return true;
}
31933285

3194-
llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
3195-
// If any of the caches are recurrent, require equal split
3196-
return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
3197-
}
3198-
3199-
llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
3200-
if (embd_pooled) {
3201-
// Pooled embeddings cannot be split across ubatches (yet)
3202-
return sbatch.split_seq(n_ubatch);
3203-
}
3204-
if (m_has_recurrent) {
3205-
return sbatch.split_equal(n_ubatch);
3206-
}
3207-
return sbatch.split_simple(n_ubatch);
3208-
}
3209-
3210-
bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
3211-
bool found = true;
3212-
for (const auto & cache : m_children) {
3213-
found = cache->find_slot(batch) && found;
3214-
}
3215-
return found;
3216-
}
3217-
32183286
bool llama_kv_cache_hybrid::get_can_shift() const {
32193287
// TODO: Is this correct?
32203288
// If any children can shift, return true
3221-
for (const auto & cache : m_children) {
3289+
for (const auto & cache : children) {
32223290
if (cache->get_can_shift()) {
32233291
return true;
32243292
}
@@ -3229,15 +3297,15 @@ bool llama_kv_cache_hybrid::get_can_shift() const {
32293297
// Write each child cache's state in order. The order is deterministic because
// `children` is an ordered set sorted by lowest layer ID at construction.
void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    for (const auto & child : children) {
        child->state_write(io, seq_id);
    }
}
32363304

32373305
// Read each child cache's state in order, mirroring state_write(). The order
// is deterministic: `children` is an ordered set sorted by lowest layer ID.
void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    for (const auto & child : children) {
        child->state_read(io, seq_id);
    }
}

src/llama-kv-cache.h

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -462,18 +462,15 @@ class llama_kv_cache_hybrid : public llama_kv_cache {
462462
: child(std::move(child_)), layer_ids(std::move(layer_ids_)) {}
463463
};
464464

465-
llama_kv_cache_hybrid(
466-
const llama_hparams & hparams,
467-
std::vector<child_cache> children);
468-
465+
explicit llama_kv_cache_hybrid(std::vector<child_cache> children);
469466
virtual ~llama_kv_cache_hybrid() = default;
470467

471468
// getters for specific child cache type
472469
// NOTE: This will fail if there are multiple of the given type
473470
template<typename child_t>
474471
const child_t * get_child_cache() const {
475472
const child_t * child = nullptr;
476-
for (const auto & child_cache : m_children) {
473+
for (const auto & child_cache : children) {
477474
const child_t * child_cast = dynamic_cast<const child_t *>(child_cache.get());
478475
if (child_cast) {
479476
GGML_ASSERT(!child);
@@ -502,8 +499,12 @@ class llama_kv_cache_hybrid : public llama_kv_cache {
502499
// llama_kv_cache
503500
//
504501

505-
void restore() override;
506-
void commit() override;
502+
llama_memory_decode_state_ptr init(
503+
const llama_batch & batch,
504+
uint32_t n_ubatch,
505+
bool embd_pooled,
506+
bool logits_all,
507+
bool split_equal = true) override;
507508

508509
bool update(llama_context & ctx) override;
509510

@@ -513,15 +514,6 @@ class llama_kv_cache_hybrid : public llama_kv_cache {
513514

514515
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
515516

516-
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
517-
518-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
519-
520-
// updates the cache head
521-
// Note: On success, it's important that cache.head points
522-
// to the first cell of the slot.
523-
bool find_slot(const llama_ubatch & batch) override;
524-
525517
bool get_can_shift() const override;
526518

527519
// state write/load
@@ -531,7 +523,6 @@ class llama_kv_cache_hybrid : public llama_kv_cache {
531523

532524
private:
533525

534-
const llama_hparams & m_hparams;
535-
const std::set<std::unique_ptr<llama_kv_cache>> m_children; // Ordered for state IO
536-
const bool m_has_recurrent;
526+
const std::set<std::unique_ptr<llama_kv_cache>> children; // Ordered for state IO
527+
const bool has_recurrent;
537528
};

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13264,7 +13264,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1326413264
);
1326513265

1326613266
// initialize the hybrid cache with both children
13267-
res = new llama_kv_cache_hybrid(hparams, std::move(children));
13267+
res = new llama_kv_cache_hybrid(std::move(children));
1326813268
} else if (llm_arch_is_recurrent(arch)) {
1326913269
res = new llama_kv_cache_recurrent(
1327013270
*this,

0 commit comments

Comments
 (0)