Skip to content

Commit 0b3853d

Browse files
committed
fix: Give ownership of child caches to the hybrid cache
The parent should fully own the lifecycle of its children, which is managed by the m_children member holding unique_ptrs. These need to be initialized correctly, so the constructor now takes the input vector of child_cache by value instead of by const reference, allowing ownership of the child pointers to be transferred to the parent cache. The expectation is that callers will construct the vector of child_cache instances in place and move it into the constructor. Branch: HybridCache Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 04c94d6 commit 0b3853d

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

src/llama-kv-cache.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,23 +2419,23 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
24192419
// llama_kv_cache_hybrid
24202420
//
24212421
llama_kv_cache_hybrid::llama_kv_cache_hybrid(
2422-
const llama_hparams & hparams,
2423-
const std::vector<child_cache> & children) :
2422+
const llama_hparams & hparams,
2423+
std::vector<child_cache> children) :
24242424
m_hparams(hparams),
24252425
m_layer_cache_map(
24262426
[](const std::vector<child_cache>& caches) -> std::unordered_map<size_t, llama_kv_cache*> {
24272427
std::unordered_map<size_t, llama_kv_cache*> map;
24282428
for (const auto & cache : caches) {
24292429
for (size_t layer_id : cache.layer_ids) {
2430-
map[layer_id] = cache.child;
2430+
map[layer_id] = cache.child.get();
24312431
}
24322432
}
24332433

24342434
return map;
24352435
}(children)
24362436
),
24372437
m_children(
2438-
[](std::vector<child_cache> caches) -> std::set<llama_kv_cache*> {
2438+
[](std::vector<child_cache>& caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
24392439
// Sort the caches by the lowest layer ID so the order is repeatable
24402440
for (auto & cache : caches) {
24412441
GGML_ASSERT(cache.layer_ids.size() > 0);
@@ -2444,22 +2444,22 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid(
24442444
std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
24452445
return a.layer_ids[0] < b.layer_ids[0];
24462446
});
2447-
std::set<llama_kv_cache*> unique_caches;
2448-
for (const auto & cache : caches) {
2449-
unique_caches.insert(cache.child);
2447+
std::set<std::unique_ptr<llama_kv_cache>> unique_caches;
2448+
for (auto & cache : caches) {
2449+
unique_caches.emplace(cache.child.release());
24502450
}
24512451
return unique_caches;
24522452
}(children)
24532453
),
24542454
m_has_recurrent(
2455-
[](const std::vector<child_cache>& caches) -> bool {
2455+
[](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
24562456
for (const auto & cache : caches) {
2457-
if (dynamic_cast<llama_kv_cache_recurrent *>(cache.child)) {
2457+
if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
24582458
return true;
24592459
}
24602460
}
24612461
return false;
2462-
}(children)
2462+
}(m_children)
24632463
)
24642464
{
24652465
// Ensure at least one child

src/llama-kv-cache.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -413,13 +413,16 @@ class llama_kv_cache_hybrid : public llama_kv_cache {
413413
public:
414414

415415
struct child_cache {
416-
llama_kv_cache * child;
417-
std::vector<size_t> layer_ids;
416+
std::unique_ptr<llama_kv_cache> child;
417+
std::vector<size_t> layer_ids;
418+
419+
child_cache(std::unique_ptr<llama_kv_cache> child_, std::vector<size_t> layer_ids_)
420+
: child(std::move(child_)), layer_ids(std::move(layer_ids_)) {}
418421
};
419422

420423
llama_kv_cache_hybrid(
421424
const llama_hparams & hparams,
422-
const std::vector<child_cache> & children);
425+
std::vector<child_cache> children);
423426

424427
//
425428
// llama_memory_i
@@ -476,7 +479,7 @@ class llama_kv_cache_hybrid : public llama_kv_cache {
476479

477480
const llama_hparams & m_hparams;
478481
const std::unordered_map<size_t, llama_kv_cache *> m_layer_cache_map;
479-
const std::set<llama_kv_cache *> m_children; // Ordered for state IO
482+
const std::set<std::unique_ptr<llama_kv_cache>> m_children; // Ordered for state IO
480483
const bool m_has_recurrent;
481484
};
482485

0 commit comments

Comments
 (0)