Skip to content

Commit 7778b5b

Browse files
committed
feat: First pass at llama_kv_cache_hybrid
This implementation covers both `llama_memory_i` and `llama_kv_cache` interfaces, but they could very well not be correct. Branch: HybridCache Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 2f56761 commit 7778b5b

File tree

2 files changed

+299
-0
lines changed

2 files changed

+299
-0
lines changed

src/llama-kv-cache.cpp

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2384,6 +2384,231 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
23842384
return true;
23852385
}
23862386

2387+
//
2388+
// llama_kv_cache_hybrid
2389+
//
2390+
llama_kv_cache_hybrid::llama_kv_cache_hybrid(
2391+
const llama_hparams & hparams,
2392+
const std::vector<child_cache> & children) :
2393+
m_hparams(hparams),
2394+
m_layer_cache_map(
2395+
[](const std::vector<child_cache>& caches) -> std::unordered_map<size_t, llama_kv_cache*> {
2396+
std::unordered_map<size_t, llama_kv_cache*> map;
2397+
for (const auto & cache : caches) {
2398+
for (size_t layer_id : cache.layer_ids) {
2399+
map[layer_id] = cache.child;
2400+
}
2401+
}
2402+
2403+
return map;
2404+
}(children)
2405+
),
2406+
m_children(
2407+
[](std::vector<child_cache> caches) -> std::set<llama_kv_cache*> {
2408+
// Sort the caches by the lowest layer ID so the order is repeatable
2409+
for (auto & cache : caches) {
2410+
GGML_ASSERT(cache.layer_ids.size() > 0);
2411+
std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
2412+
}
2413+
std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
2414+
return a.layer_ids[0] < b.layer_ids[0];
2415+
});
2416+
std::set<llama_kv_cache*> unique_caches;
2417+
for (const auto & cache : caches) {
2418+
unique_caches.insert(cache.child);
2419+
}
2420+
return unique_caches;
2421+
}(children)
2422+
),
2423+
m_has_recurrent(
2424+
[](const std::vector<child_cache>& caches) -> bool {
2425+
for (const auto & cache : caches) {
2426+
if (dynamic_cast<llama_kv_cache_recurrent *>(cache.child)) {
2427+
return true;
2428+
}
2429+
}
2430+
return false;
2431+
}(children)
2432+
)
2433+
{
2434+
// Ensure at least one child
2435+
GGML_ASSERT(m_children.size() > 0);
2436+
2437+
// Ensure layers are not overlapping and are concurrent
2438+
std::set<size_t> seen_layers;
2439+
size_t max_layer = 0;
2440+
for (const auto & cache : children) {
2441+
for (const auto & layer_id : cache.layer_ids) {
2442+
GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
2443+
seen_layers.insert(layer_id);
2444+
if (layer_id > max_layer) {
2445+
max_layer = layer_id;
2446+
}
2447+
}
2448+
}
2449+
GGML_ASSERT(max_layer == seen_layers.size());
2450+
}
2451+
2452+
void llama_kv_cache_hybrid::clear() {
2453+
for (const auto & cache : m_children) {
2454+
cache->clear();
2455+
}
2456+
}
2457+
2458+
// Forward seq_rm to every child; true only if every child succeeded.
// Every child is attempted even after a failure.
// TODO: Will it cause problems if some caches are able to remove the seq
//  but others aren't?
bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    bool all_removed = true;
    for (llama_kv_cache * child : m_children) {
        if (!child->seq_rm(seq_id, p0, p1)) {
            all_removed = false;
        }
    }
    return all_removed;
}
2467+
2468+
// Forward seq_cp to every child cache.
void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    for (llama_kv_cache * child : m_children) {
        child->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    }
}
2473+
2474+
// Forward seq_keep to every child cache.
void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
    for (llama_kv_cache * child : m_children) {
        child->seq_keep(seq_id);
    }
}
2479+
2480+
// Forward seq_add to every child cache.
void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
    for (llama_kv_cache * child : m_children) {
        child->seq_add(seq_id, p0, p1, delta);
    }
}
2485+
2486+
// Forward seq_div to every child cache.
void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    for (llama_kv_cache * child : m_children) {
        child->seq_div(seq_id, p0, p1, d);
    }
}
2491+
2492+
// The hybrid maximum position for a sequence is the largest value reported
// by any child (never below 0, matching the original initialization).
llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos result = 0;
    for (llama_kv_cache * child : m_children) {
        const llama_pos child_max = child->seq_pos_max(seq_id);
        if (child_max > result) {
            result = child_max;
        }
    }
    return result;
}
2499+
2500+
void llama_kv_cache_hybrid::restore() {
2501+
for (const auto & cache : m_children) {
2502+
cache->restore();
2503+
}
2504+
}
2505+
2506+
void llama_kv_cache_hybrid::commit() {
2507+
for (const auto & cache : m_children) {
2508+
cache->commit();
2509+
}
2510+
}
2511+
2512+
// Let every child update itself; reports whether any child changed.
// Each child is always invoked (no short-circuit once one returns true).
bool llama_kv_cache_hybrid::update(llama_context & ctx) {
    bool any_updated = false;
    for (llama_kv_cache * child : m_children) {
        if (child->update(ctx)) {
            any_updated = true;
        }
    }
    return any_updated;
}
2519+
2520+
void llama_kv_cache_hybrid::defrag_sched(float thold) {
2521+
for (const auto & cache : m_children) {
2522+
cache->defrag_sched(thold);
2523+
}
2524+
}
2525+
2526+
void llama_kv_cache_hybrid::set_full() {
2527+
for (const auto & cache : m_children) {
2528+
cache->set_full();
2529+
}
2530+
}
2531+
2532+
llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
2533+
// If any of the caches are recurrent, require simple split
2534+
return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
2535+
}
2536+
2537+
llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
2538+
if (m_has_recurrent) {
2539+
return sbatch.split_simple(n_ubatch);
2540+
}
2541+
if (embd_pooled) {
2542+
// Pooled embeddings cannot be split across ubatches (yet)
2543+
return sbatch.split_seq(n_ubatch);
2544+
}
2545+
return sbatch.split_equal(n_ubatch);
2546+
}
2547+
2548+
bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
2549+
bool found = true;
2550+
for (const auto & cache : m_children) {
2551+
found = cache->find_slot(batch) && found;
2552+
}
2553+
return found;
2554+
}
2555+
2556+
int32_t llama_kv_cache_hybrid::get_n_tokens() const {
2557+
// The number of tokens should be the same across all child caches
2558+
int32_t n_tokens = -1;
2559+
for (const auto & cache : m_children) {
2560+
const auto cache_n_tokens = cache->get_n_tokens();
2561+
GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
2562+
n_tokens = cache_n_tokens;
2563+
}
2564+
return n_tokens;
2565+
}
2566+
2567+
int32_t llama_kv_cache_hybrid::get_used_cells() const {
2568+
// TODO: Is this correct?
2569+
// Return the largetst number of used cells
2570+
int32_t used_cells = -1;
2571+
for (const auto & cache : m_children) {
2572+
used_cells = std::max(used_cells, cache->get_used_cells());
2573+
}
2574+
return used_cells;
2575+
}
2576+
2577+
// The hybrid's max position is the largest reported by any child.
llama_pos llama_kv_cache_hybrid::get_pos_max() const {
    llama_pos pos_max = -1;
    for (llama_kv_cache * child : m_children) {
        const llama_pos child_pos_max = child->get_pos_max();
        if (child_pos_max > pos_max) {
            pos_max = child_pos_max;
        }
    }
    return pos_max;
}
2584+
2585+
bool llama_kv_cache_hybrid::get_can_shift() const {
2586+
// TODO: Is this correct?
2587+
// If any children can shift, return true
2588+
for (const auto & cache : m_children) {
2589+
if (cache->get_can_shift()) {
2590+
return true;
2591+
}
2592+
}
2593+
return false;
2594+
}
2595+
2596+
// Serialize the children in m_children's iteration order.
// NOTE(review): m_children is a std::set of pointers, which iterates in
//  pointer order, not the layer-sorted order the constructor intends --
//  confirm this ordering is actually deterministic for state round-trips.
void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    for (llama_kv_cache * child : m_children) {
        child->state_write(io, seq_id);
    }
}
2603+
2604+
// Deserialize the children in m_children's iteration order (must match the
// order used by state_write).
// NOTE(review): m_children is a std::set of pointers, which iterates in
//  pointer order, not the layer-sorted order the constructor intends --
//  confirm this ordering is actually deterministic for state round-trips.
void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    for (llama_kv_cache * child : m_children) {
        child->state_read(io, seq_id);
    }
}
2611+
23872612
//
23882613
// kv cache view
23892614
//

src/llama-kv-cache.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <set>
1111
#include <vector>
12+
#include <unordered_map>
1213

1314
struct llama_cparams;
1415
struct llama_hparams;
@@ -395,6 +396,79 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
395396
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
396397
};
397398

399+
//
400+
// llama_kv_cache_hybrid
401+
//
402+
403+
// A KV cache that composes several child caches, routing each model layer to
// exactly one child (e.g. attention layers -> one cache, recurrent layers ->
// another). Most operations simply fan out to every child.
class llama_kv_cache_hybrid : public llama_kv_cache {
public:

    // Binds one child cache to the model layer indices it serves
    struct child_cache {
        llama_kv_cache * child;        // non-owning; must outlive the hybrid cache
        std::vector<size_t> layer_ids; // layer indices handled by this child
    };

    // The children's layer_ids must be non-overlapping and together cover a
    // contiguous range of layers (asserted in the constructor)
    llama_kv_cache_hybrid(
        const llama_hparams & hparams,
        const std::vector<child_cache> & children);

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    void restore() override;
    void commit() override;

    bool update(llama_context & ctx) override;

    void defrag_sched(float thold) override;

    void set_full() override;

    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;

    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

    // updates the cache head
    // Note: On success, it's important that cache.head points
    // to the first cell of the slot.
    bool find_slot(const llama_ubatch & batch) override;

    int32_t get_n_tokens() const override;
    int32_t get_used_cells() const override;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos get_pos_max() const override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

private:

    // Model hyperparameters, held by reference — the owner must outlive this cache
    const llama_hparams & m_hparams;
    // layer index -> the child cache that owns that layer
    // NOTE(review): built by the constructor but not read anywhere in this
    //  commit — presumably for upcoming per-layer dispatch; verify before relying on it
    const std::unordered_map<size_t, llama_kv_cache *> m_layer_cache_map;
    const std::set<llama_kv_cache *> m_children; // Ordered for state IO
    // NOTE(review): a std::set of pointers iterates in pointer order, so the
    //  "ordered for state IO" claim above is not guaranteed — confirm
    // True if any child is a llama_kv_cache_recurrent (affects batch splitting)
    const bool m_has_recurrent;
};
471+
398472

399473
//
400474
// kv cache view

0 commit comments

Comments
 (0)