From be9558e34e957cf666c070724bc581c86962f290 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 22 May 2025 13:01:01 +0300 Subject: [PATCH 1/7] kv-cache : rework kv_cell ggml-ci --- src/llama-kv-cache.cpp | 366 +++++++++++++++++------------------------ src/llama-kv-cache.h | 47 +++--- src/llama-kv-cells.h | 259 +++++++++++++++++++++++++++++ 3 files changed, 432 insertions(+), 240 deletions(-) create mode 100644 src/llama-kv-cells.h diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a2624d71589b5..e35497d038f41 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -65,8 +65,6 @@ llama_kv_cache_unified::llama_kv_cache_unified( }; head = 0; - size = kv_size; - used = 0; cells.resize(kv_size); @@ -138,13 +136,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( } void llama_kv_cache_unified::clear() { - for (uint32_t i = 0; i < size; ++i) { - cells[i].pos = -1; - cells[i].seq_id.clear(); - } + cells.reset(); head = 0; - used = 0; for (auto & buf : bufs) { ggml_backend_buffer_clear(buf.get(), 0); @@ -152,7 +146,7 @@ void llama_kv_cache_unified::clear() { } bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = size; + uint32_t new_head = cells.size(); if (p0 < 0) { p0 = 0; @@ -162,33 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].pos >= p0 && cells[i].pos < p1) { - if (seq_id < 0) { - cells[i].seq_id.clear(); - } else if (cells[i].has_seq_id(seq_id)) { - cells[i].seq_id.erase(seq_id); - } else { - continue; - } - - if (cells[i].is_empty()) { - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } - if (new_head == size) { - new_head = i; - } + if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; } } } // If we freed up a slot, set head to it so searching can start there. - if (new_head != size && new_head < head) { + if (new_head != cells.size() && new_head < head) { head = new_head; } @@ -208,39 +189,30 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id p1 = std::numeric_limits::max(); } - // otherwise, this is the KV of a Transformer-like model - head = 0; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) { - cells[i].seq_id.insert(seq_id_dst); + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); } } } void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { - uint32_t new_head = size; + uint32_t new_head = cells.size(); - for (uint32_t i = 0; i < size; ++i) { - if (!cells[i].has_seq_id(seq_id)) { - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - cells[i].seq_id.clear(); - - if (new_head == size){ + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.seq_keep(i, seq_id)) { + if (new_head == cells.size()) { new_head = i; } - } else { - cells[i].seq_id.clear(); - cells[i].seq_id.insert(seq_id); } } // If we freed up a slot, set head to it so searching can start there. - if (new_head != size && new_head < head) { + if (new_head != cells.size() && new_head < head) { head = new_head; } } @@ -250,7 +222,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po return; } - uint32_t new_head = size; + uint32_t new_head = cells.size(); if (p0 < 0) { p0 = 0; @@ -260,25 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po p1 = std::numeric_limits::max(); } - // If there is no range then return early to avoid looping over the + // If there is no range then return early to avoid looping over all cells. if (p0 == p1) { return; } - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { - has_shift = true; - - cells[i].pos += delta; - cells[i].delta += delta; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } - if (cells[i].pos < 0) { - if (!cells[i].is_empty()) { - used--; - } - cells[i].pos = -1; - cells[i].seq_id.clear(); - if (new_head == size) { + if (cells.seq_has(i, seq_id)) { + if (cells.pos_add(i, delta)) { + if (new_head == cells.size()) { new_head = i; } } @@ -287,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po // If we freed up a slot, set head to it so searching can start there. // Otherwise we just start the next search from the beginning. - head = new_head != size ? new_head : 0; + head = new_head != cells.size() ? new_head : 0; } void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { @@ -308,15 +274,13 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po return; } - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { - has_shift = true; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } - { - llama_pos p_old = cells[i].pos; - cells[i].pos /= d; - cells[i].delta += cells[i].pos - p_old; - } + if (cells.seq_has(i, seq_id)) { + cells.pos_div(i, d); } } } @@ -324,9 +288,9 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { llama_pos result = std::numeric_limits::max(); - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::min(result, cells[i].pos); + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.seq_has(i, seq_id)) { + result = std::min(result, cells.get_pos(i)); } } @@ -340,9 +304,9 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { llama_pos result = -1; - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::max(result, cells[i].pos); + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.seq_has(i, seq_id)) { + result = std::max(result, cells.get_pos(i)); } } @@ -350,25 +314,15 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { } void llama_kv_cache_unified::restore() { - for (const auto & [id, cell] : recovery.cells) { - // TODO: move to new `struct kv_cells` - const bool is_empty0 = cells[id].is_empty(); - const bool is_empty1 = cell.is_empty(); - - if (!is_empty0 && is_empty1) { - used--; - } else if (is_empty0 && !is_empty1) { - used++; - } - - cells[id] = cell; + for (auto & state : recovery.states) { + cells.set(state.i, state.cells); } recovery.clear(); } void llama_kv_cache_unified::commit() { - if (recovery.cells.empty()) { + if (recovery.states.empty()) { LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n", __func__, "https://github.com/ggml-org/llama.cpp/pull/13194"); return; @@ -382,7 +336,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { auto * sched = lctx.get_sched(); - if (has_shift) { + if (cells.pos_has_shift()) { if (!get_can_shift()) { GGML_ABORT("The current KV cache / model configuration does not support K-shift"); } @@ -406,13 +360,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { need_reserve = true; } - { - has_shift = false; - - for (uint32_t i = 0; i < size; ++i) { - cells[i].delta = 0; - } - } + cells.pos_reset_delta(); } if (do_defrag) { @@ -443,7 +391,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { void llama_kv_cache_unified::defrag_sched(float thold) { // - do not defrag small contexts (i.e. < 2048 tokens) // - count the padding towards the number of used tokens - const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f; + const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f; // queue defragmentation for next llama_kv_cache_update if (fragmentation > thold) { @@ -454,7 +402,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) { } void llama_kv_cache_unified::set_full() { - n = size; + n = cells.size(); // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. @@ -478,14 +426,14 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it - if (head > used + 2*ubatch.n_tokens) { + if (head > cells.get_used() + 2*ubatch.n_tokens) { head = 0; } // otherwise, one cell per token. - if (n_tokens > size) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); + if (n_tokens > cells.size()) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); return false; } @@ -498,10 +446,10 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { std::string ss; if (n_swa > 0) { for (uint32_t i = 0; i < size; ++i) { - if (cells[i].pos == -1) { + if (cells.is_empty(i)) { ss += '.'; } else { - ss += std::to_string(*cells[i].seq_id.begin()); + ss += 'x'; } if (i%256 == 255) { ss += '\n'; @@ -515,15 +463,16 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { uint32_t n_tested = 0; while (true) { - if (head + n_tokens > size) { - n_tested += size - head; + if (head + n_tokens > cells.size()) { + n_tested += cells.size() - head; head = 0; continue; } bool found = true; for (uint32_t i = 0; i < n_tokens; i++) { - if (cells[head + i].pos >= 0) { + // TODO: improve to accept cells that are masked by the SWA + if (!cells.is_empty(head + i)) { found = false; head += i + 1; n_tested += i + 1; @@ -535,31 +484,27 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { break; } - if (n_tested >= size) { + if (n_tested >= cells.size()) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); return false; } } - for (uint32_t i = 0; i < n_tokens; ++i) { - // remember the original state - if (recovery.cells.find(head + i) == recovery.cells.end()) { - recovery.cells[head + i] = cells[head + i]; - } + // store the old state of the cells in the recovery stack + recovery.states.push_back({head, cells.cp(head, n_tokens)}); - cells[head + i].pos = ubatch.pos[i]; + for (uint32_t i = 0; i < n_tokens; ++i) { + cells.pos_set(head + i, ubatch.pos[i]); for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { - cells[head + i].seq_id.insert(ubatch.seq_id[i][j]); + cells.seq_add(head + i, ubatch.seq_id[i][j]); } } - used += n_tokens; - // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad))); + n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad))); #ifdef FIND_SLOT_DEBUG LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); @@ -577,7 +522,7 @@ uint32_t llama_kv_cache_unified::get_n() const { } uint32_t llama_kv_cache_unified::get_size() const { - return size; + return cells.size(); } ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const { @@ -661,30 +606,19 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam int n_attended = 0; - for (uint32_t i = 0; i < size; ++i) { - const llama_pos p0 = cells[i].pos; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.seq_has(i, seq_id)) { + continue; + } + + const llama_pos p0 = cells.get_pos(i); if (p0 <= pmin && !is_masked_swa(p0, pmin)) { n_attended++; } if (is_masked_swa(p0, pmax)) { - if (seq_id < 0) { - cells[i].seq_id.clear(); - } else if (cells[i].has_seq_id(seq_id)) { - cells[i].seq_id.erase(seq_id); - } else { - continue; - } - - if (cells[i].is_empty()) { - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - } + cells.seq_rm(i, seq_id); } } @@ -723,25 +657,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; for (int i = 0; i < n_kv; ++i) { - const llama_pos p0 = cells[i].pos; + float f = 0.0f; bool masked = false; - // mask the token if not the same sequence - masked = masked || (!cells[i].has_seq_id(seq_id)); + if (cells.is_empty(i)) { + masked = true; + } else { + const llama_pos p0 = cells.get_pos(i); - // mask future tokens - masked = masked || (causal_attn && p0 > p1); + // mask the token if not the same sequence + masked = masked || (!cells.seq_has(i, seq_id)); - // apply SWA if any - masked = masked || (is_masked_swa(p0, p1)); + // mask future tokens + masked = masked || (causal_attn && p0 > p1); - float f = 0.0f; + // apply SWA if any + masked = masked || (is_masked_swa(p0, p1)); + + if (!masked && hparams.use_alibi) { + f = -std::abs(p0 - p1); + } + } if (masked) { f = -INFINITY; - } else if (hparams.use_alibi) { - f = -std::abs(p0 - p1); } data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; @@ -765,8 +705,8 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { int32_t * data = (int32_t *) dst->data; - for (uint32_t i = 0; i < size; ++i) { - data[i] = cells[i].delta; + for (uint32_t i = 0; i < cells.size(); ++i) { + data[i] = cells.is_empty(i) ? 0 : cells.get_delta(i); } } @@ -783,7 +723,10 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_kv; ++i) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false); + // the position when the cells is empty is irrelevant - it will be masked out later in the attention + const llama_pos p0 = cells.is_empty(i) ? -1 : cells.get_pos(i); + + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); } } } @@ -910,7 +853,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, size, + n_embd_head_k, n_head_kv, cells.size(), ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), 0); @@ -1050,12 +993,12 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( } else { view_v_src = ggml_view_2d(ctx, layer.v, nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, size), + ggml_row_size(layer.v->type, cells.size()), ggml_row_size(layer.v->type, i)); view_v_dst = ggml_view_2d(ctx, layer.v, nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, size), + ggml_row_size(layer.v->type, cells.size()), ggml_row_size(layer.v->type, id)); } @@ -1076,7 +1019,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { const uint32_t n_layer = layers.size(); const uint32_t n_kv = cell_max(); - const uint32_t n_used = used; + const uint32_t n_used = cells.get_used(); assert(n_used <= n_kv); @@ -1104,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { ids.resize(n_kv, n_kv); for (uint32_t i0 = 0; i0 < n_used; ++i0) { - const auto & cell0 = cells[i0]; - - if (!cell0.is_empty()) { + if (!cells.is_empty(i0)) { ids[i0] = i0; continue; @@ -1117,7 +1058,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { uint32_t nh = 1; // determine the size of the hole - while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { + while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { nh++; } @@ -1126,9 +1067,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { // starting from the end, find nh non-empty cells for (; is > i0; --is) { - const auto & cell1 = cells[is]; - - if (cell1.is_empty() || ids[is] != n_kv) { + if (cells.is_empty(is) || ids[is] != n_kv) { continue; } @@ -1155,9 +1094,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { // go back and move the nf cells to the hole for (; i1 < n_kv; ++i1) { - auto & cell1 = cells[i1]; - - if (cell1.is_empty() || ids[i1] != n_kv) { + if (cells.is_empty(i1) || ids[i1] != n_kv) { if (n_moves == max_moves) { stop = true; break; @@ -1171,10 +1108,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { ids[i1] = i0 + nf; // move the cell meta data - cells[i0 + nf] = cell1; + cells.mv(i1, i0 + nf); - // clear the old cell and move the head there - cell1 = kv_cell(); head = n_used; if (!cont) { @@ -1210,10 +1145,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { } uint32_t llama_kv_cache_unified::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const kv_cell & cell = cells[i - 1]; - - if (cell.pos >= 0 && !cell.is_empty()) { + for (uint32_t i = cells.size(); i > 0; --i) { + if (!cells.is_empty(i - 1)) { return i; } } @@ -1222,9 +1155,7 @@ uint32_t llama_kv_cache_unified::cell_max() const { } bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { - if (p0 < 0) { - return true; - } + assert(p0 >= 0 && p1 >= 0); switch (swa_type) { case LLAMA_SWA_TYPE_NONE: @@ -1255,23 +1186,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq // Count the number of cells with the specified seq_id // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = size; - for (uint32_t i = 0; i < size; ++i) { - const auto & cell = cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { ++cell_count; - if (cell_range_begin == size) { + if (cell_range_begin == cells.size()) { cell_range_begin = i; } } else { - if (cell_range_begin != size) { + if (cell_range_begin != cells.size()) { cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = size; + cell_range_begin = cells.size(); } } } - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, size); + + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, cells.size()); } // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count @@ -1308,17 +1240,24 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { for (const auto & range : cell_ranges) { for (uint32_t i = range.first; i < range.second; ++i) { - const auto & cell = cells[i]; - const llama_pos pos = cell.pos; - const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + std::vector seq_ids; + + for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { + if (cur == seq_id || seq_id == -1) { + if (cells.seq_has(i, cur)) { + seq_ids.push_back(cur); + } + } + } + + const llama_pos pos = cells.get_pos(i); + const uint32_t n_seq_id = seq_ids.size(); io.write(&pos, sizeof(pos)); io.write(&n_seq_id, sizeof(n_seq_id)); - if (n_seq_id) { - for (auto seq_id : cell.seq_id) { - io.write(&seq_id, sizeof(seq_id)); - } + for (const auto & seq_id : seq_ids) { + io.write(&seq_id, sizeof(seq_id)); } } } @@ -1379,7 +1318,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: } } else { // When v is transposed, we also need the element size and get the element ranges from each row - const uint32_t kv_size = size; + const uint32_t kv_size = cells.size(); for (const auto & layer : layers) { const uint32_t il = layer.il; @@ -1429,14 +1368,20 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell io.read_to(&pos, sizeof(pos)); io.read_to(&n_seq_id, sizeof(n_seq_id)); - if (n_seq_id != 0) { + if (n_seq_id != 1) { LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); return false; } - batch.pos[i] = pos; - batch.n_seq_id[i] = 1; - batch.seq_id[i] = &dest_seq_id; + // read the sequence id, but directly discard it - we will use dest_seq_id instead + { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + } + + batch.pos[i] = pos; + batch.n_seq_id[i] = n_seq_id; + batch.seq_id[i] = &dest_seq_id; } if (!find_slot(batch)) { @@ -1448,15 +1393,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells - GGML_ASSERT(head + cell_count <= size); - GGML_ASSERT(cells[head].pos == batch.pos[0]); - GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); - GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + GGML_ASSERT(head + cell_count <= cells.size()); + GGML_ASSERT(cells.get_pos(head) == batch.pos[0]); + GGML_ASSERT(cells.get_pos(head + cell_count - 1) == batch.pos[cell_count - 1]); + GGML_ASSERT(cells.seq_has(head, dest_seq_id)); + GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id)); } else { // whole KV cache restore - if (cell_count > size) { + if (cell_count > cells.size()) { LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); return false; } @@ -1464,15 +1409,13 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell clear(); for (uint32_t i = 0; i < cell_count; ++i) { - kv_cell & cell = cells[i]; - llama_pos pos; uint32_t n_seq_id; io.read_to(&pos, sizeof(pos)); io.read_to(&n_seq_id, sizeof(n_seq_id)); - cell.pos = pos; + cells.pos_set(i, pos); for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; @@ -1483,12 +1426,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell return false; } - cell.seq_id.insert(seq_id); + cells.seq_add(i, seq_id); } } head = 0; - used = cell_count; } return true; @@ -1505,8 +1447,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); return false; } - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size()); return false; } if (this->v_trans != (bool) v_trans) { @@ -1609,7 +1551,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * size) * v_size_el; + const size_t dst_offset = (head + j * cells.size()) * v_size_el; ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 191a1090a1252..f0381a10fc9c0 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -4,6 +4,7 @@ #include "llama-io.h" #include "llama-graph.h" #include "llama-memory.h" +#include "llama-kv-cells.h" #include "ggml-cpp.h" @@ -35,6 +36,7 @@ struct llama_kv_cache : public llama_memory_i { virtual void defrag_sched(float thold) = 0; // simulate full cache, used for allocating worst-case compute buffers + // TODO: remove virtual void set_full() = 0; // @@ -42,7 +44,7 @@ struct llama_kv_cache : public llama_memory_i { // // ============================================================================================================= - // TODO: refactor and simplify this + // TODO: refactor and simplify this [TAG: KV_API] virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; @@ -159,7 +161,7 @@ class llama_kv_cache_unified : public llama_kv_cache { // llama_kv_cache_unified specific API // - uint32_t get_n() const; + uint32_t get_n() const; uint32_t get_size() const; // get views of the current state of the cache @@ -180,26 +182,6 @@ class llama_kv_cache_unified : public llama_kv_cache { const llama_model & model; const llama_hparams & hparams; - struct kv_cell { - llama_pos pos = -1; - llama_pos delta = 0; - - // TODO: replace with bitset uint64_t - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } - - bool is_empty() const { - return seq_id.empty(); - } - - bool is_same_seq(const kv_cell & other) const { - return seq_id == other.seq_id; - } - }; - struct kv_layer { // layer index in the model // note: can be different from the layer index in the KV cache @@ -209,15 +191,13 @@ class llama_kv_cache_unified : public llama_kv_cache { ggml_tensor * v; }; - bool has_shift = false; bool do_defrag = false; bool v_trans = true; // the value tensor is transposed uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) - uint32_t size = 0; // total number of cells, shared across all sequences - uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt) // computed before each graph build + // TODO: cells should start to maintain this value dynamically based on the edits uint32_t n = 0; const uint32_t n_seq_max = 1; @@ -233,19 +213,29 @@ class llama_kv_cache_unified : public llama_kv_cache { std::vector ctxs; std::vector bufs; - std::vector cells; // TODO: replace with `struct kv_cells` + llama_kv_cells_unified cells; + std::vector layers; // model layer id -> KV cache layer id std::unordered_map map_layer_ids; // recovery information used to restore the KV cells to their original state in case of a failure + // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation + // to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API] struct { void clear() { - cells.clear(); + states.clear(); } - std::unordered_map cells; + struct state { + uint32_t i; + + llama_kv_cells_unified cells; + }; + + // stack with the partial states before each ubatch + std::vector states; } recovery; // defrag @@ -257,6 +247,7 @@ class llama_kv_cache_unified : public llama_kv_cache { bool defrag_prepare(int32_t n_max_nodes); // find how many cells are currently in use + // TODO: optimize uint32_t cell_max() const; size_t total_size() const; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h new file mode 100644 index 0000000000000..a607ee89cfffc --- /dev/null +++ b/src/llama-kv-cells.h @@ -0,0 +1,259 @@ +#pragma once + +#include +#include +#include + +using llama_pos = int32_t; +using llama_seq_id = int32_t; + +// meta information about KV cells that can be part of multiple sequences at the same time +// TODO: add unit tests +struct llama_kv_cells_unified { + void reset() { + for (uint32_t i = 0; i < pos.size(); ++i) { + pos[i] = -1; + delta[i] = 0; + seq[i].reset(); + } + + used = 0; + has_delta = false; + } + + uint32_t size() const { + return pos.size(); + } + + void resize(uint32_t n) { + pos.resize(n); + delta.resize(n); + seq.resize(n); + + reset(); + } + + bool is_empty(uint32_t i) const { + assert(i < pos.size()); + assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0); + + return pos[i] == -1; + } + + uint32_t get_used() const { + return used; + } + + // move cell isrc to idst + void mv(uint32_t isrc, uint32_t idst) { + assert(isrc < pos.size()); + assert(idst < pos.size()); + + pos [idst] = pos [isrc]; + delta[idst] = delta[isrc]; + seq [idst] = seq [isrc]; + + pos [isrc] = -1; + delta[isrc] = 0; + seq [isrc].reset(); + } + + // copy the state of cells [i, i + n) + llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { + assert(i + n <= pos.size()); + + llama_kv_cells_unified res; + + res.resize(n); + + for (uint32_t j = 0; j < n; ++j) { + res.pos[j] = pos[i + j]; + res.seq[j] = seq[i + j]; + + assert(delta[i + j] == 0); + } + + return res; + } + + // set the state of cells [i, i + other.pos.size()) + void set(uint32_t i, const llama_kv_cells_unified & other) { + assert(i + other.pos.size() <= pos.size()); + + for (uint32_t j = 0; j < other.pos.size(); ++j) { + if (pos[i + j] == -1 && other.pos[j] != -1) { + used++; + } + + if (pos[i + j] != -1 && other.pos[j] == -1) { + used--; + } + + pos[i + j] = other.pos[j]; + seq[i + j] = other.seq[j]; + + assert(delta[i + j] == 0); + } + } + + // note: call only if the cell has seq_id + // return true if the cell becomes empty + bool seq_rm(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + assert(seq[i].test(seq_id)); + assert(pos[i] != -1); + assert(seq_id >= 0); + + seq[i].reset(seq_id); + + if (seq[i].none()) { + pos[i]= -1; + + used--; + + return true; + } + + return false; + } + + // return true if the cell becomes empty (i.e. it did not contain seq_id before the call) + bool seq_keep(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + + if (seq[i].test(seq_id)) { + seq[i].reset(); + seq[i].set(seq_id); + + return false; + } + + if (seq[i].any()) { + seq[i].reset(); + pos[i] = -1; + + used--; + + return true; + } + + assert(pos[i] == -1); + + return false; + } + + bool seq_has(uint32_t i, llama_seq_id seq_id) const { + assert(i < pos.size()); + assert(seq_id >= 0); + + return seq[i].test(seq_id); + } + + // note: call only if the cell is not empty + bool seq_add(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + assert(pos[i] != -1); + + if (seq[i].none()) { + seq[i].set(seq_id); + + used++; + + return true; + } + + return false; + } + + // note: call only if the cell is not empty + llama_pos get_pos(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return pos[i]; + } + + // note: call only if the cell is not empty + llama_pos get_delta(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return delta[i]; + } + + bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { + assert(i < pos.size()); + + return pos[i] >= p0 && pos[i] < p1; + } + + // note: call only if the cell is empty + void pos_set(uint32_t i, llama_pos p) { + assert(i < pos.size()); + assert(pos[i] == -1); + + pos[i] = p; + used++; + } + + // pos[i] = pos[i] + d + // note: call only if the cell is not empty + bool pos_add(uint32_t i, llama_pos d) { + assert(i < pos.size()); + assert(pos[i] != -1); + + pos[i] += d; + delta[i] += d; + + has_delta = true; + + if (pos[i] < 0) { + pos[i] = -1; + seq[i].reset(); + + used--; + + return true; + } + + return false; + } + + // pos[i] = pos[i] / d + // note: call only if the cell is not empty + void pos_div(uint32_t i, int d) { + assert(i < pos.size()); + assert(pos[i] != -1); + + const llama_pos p_old = pos[i]; + + pos[i] /= d; + delta[i] += p_old - pos[i]; + + has_delta = true; + } + + bool pos_has_shift() const { + return has_delta; + } + + void pos_reset_delta() { + has_delta = false; + + for (uint32_t i = 0; i < delta.size(); ++i) { + delta[i] = 0; + } + } + +private: + uint32_t used = 0; // used cells (i.e. at least one seq_id) + + bool has_delta = false; + + std::vector pos; + std::vector delta; + + // TODO: assert n_seq_max <= 64 + std::vector> seq; +}; + From 71be7e50855e97ac8847576285b935a811cc1c0e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 23 May 2025 11:49:58 +0300 Subject: [PATCH 2/7] kv-cells : use "shift" instead of "delta" consistently ggml-ci --- src/llama-kv-cache.cpp | 24 +++++++------- src/llama-kv-cache.h | 6 ++-- src/llama-kv-cells.h | 75 +++++++++++++++++++++++++----------------- src/llama-memory.h | 2 +- 4 files changed, 60 insertions(+), 47 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index e35497d038f41..d1d9fe3f88309 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -217,8 +217,8 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { } } -void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (delta == 0) { +void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { return; } @@ -243,7 +243,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po } if (cells.seq_has(i, seq_id)) { - if (cells.pos_add(i, delta)) { + if (cells.pos_add(i, shift)) { if (new_head == cells.size()) { new_head = i; } @@ -336,7 +336,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { auto * sched = lctx.get_sched(); - if (cells.pos_has_shift()) { + if (cells.get_has_shift()) { if (!get_can_shift()) { GGML_ABORT("The current KV cache / model configuration does not support K-shift"); } @@ -360,7 +360,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { need_reserve = true; } - cells.pos_reset_delta(); + cells.reset_shift(); } if (do_defrag) { @@ -706,7 +706,7 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { int32_t * data = (int32_t *) dst->data; for (uint32_t i = 0; i < cells.size(); ++i) { - data[i] = cells.is_empty(i) ? 0 : cells.get_delta(i); + data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i); } } @@ -1631,9 +1631,9 @@ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { kv_swa ->seq_keep(seq_id); } -void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - kv_base->seq_add(seq_id, p0, p1, delta); - kv_swa ->seq_add(seq_id, p0, p1, delta); +void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_base->seq_add(seq_id, p0, p1, shift); + kv_swa ->seq_add(seq_id, p0, p1, shift); } void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { @@ -2005,8 +2005,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { } } -void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (delta == 0) { +void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { return; } @@ -2029,7 +2029,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_ if (tail_id >= 0) { kv_cell & cell = cells[tail_id]; if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; + cell.pos += shift; } } } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index f0381a10fc9c0..86a96820e2420 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -123,7 +123,7 @@ class llama_kv_cache_unified : public llama_kv_cache { bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; llama_pos seq_pos_min(llama_seq_id seq_id) const override; @@ -316,7 +316,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; llama_pos seq_pos_min(llama_seq_id seq_id) const override; @@ -422,7 +422,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache { bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; llama_pos seq_pos_min(llama_seq_id seq_id) const override; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index a607ee89cfffc..eaf5ec9c0a86c 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -1,24 +1,32 @@ #pragma once +#include "llama.h" + #include #include #include -using llama_pos = int32_t; -using llama_seq_id = int32_t; - // meta information about KV cells that can be part of multiple sequences at the same time // TODO: add unit tests -struct llama_kv_cells_unified { +class llama_kv_cells_unified { +public: void reset() { for (uint32_t i = 0; i < pos.size(); ++i) { pos[i] = -1; - delta[i] = 0; + shift[i] = 0; seq[i].reset(); } used = 0; - has_delta = false; + has_shift = false; + } + + void reset_shift() { + has_shift = false; + + for (uint32_t i = 0; i < shift.size(); ++i) { + shift[i] = 0; + } } uint32_t size() const { @@ -27,7 +35,7 @@ struct llama_kv_cells_unified { void resize(uint32_t n) { pos.resize(n); - delta.resize(n); + shift.resize(n); seq.resize(n); reset(); @@ -44,17 +52,21 @@ struct llama_kv_cells_unified { return used; } + bool get_has_shift() const { + return has_shift; + } + // move cell isrc to idst void mv(uint32_t isrc, uint32_t idst) { assert(isrc < pos.size()); assert(idst < pos.size()); pos [idst] = pos [isrc]; - delta[idst] = delta[isrc]; + shift[idst] = shift[isrc]; seq [idst] = seq [isrc]; pos [isrc] = -1; - delta[isrc] = 0; + shift[isrc] = 0; seq [isrc].reset(); } @@ -70,7 +82,7 @@ struct llama_kv_cells_unified { res.pos[j] = pos[i + j]; res.seq[j] = seq[i + j]; - assert(delta[i + j] == 0); + assert(shift[i + j] == 0); } return res; @@ -92,7 +104,7 @@ struct llama_kv_cells_unified { pos[i + j] = other.pos[j]; seq[i + j] = other.seq[j]; - assert(delta[i + j] == 0); + assert(shift[i + j] == 0); } } @@ -174,11 +186,11 @@ struct llama_kv_cells_unified { } // note: call only if the cell is not empty - llama_pos get_delta(uint32_t i) const { + llama_pos get_shift(uint32_t i) const { assert(i < pos.size()); assert(pos[i] != -1); - return delta[i]; + return shift[i]; } bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { @@ -203,9 +215,9 @@ struct llama_kv_cells_unified { assert(pos[i] != -1); pos[i] += d; - delta[i] += d; + shift[i] += d; - has_delta = true; + has_shift = true; if (pos[i] < 0) { pos[i] = -1; @@ -228,30 +240,31 @@ struct llama_kv_cells_unified { const llama_pos p_old = pos[i]; pos[i] /= d; - delta[i] += p_old - pos[i]; + shift[i] += p_old - pos[i]; - has_delta = true; - } - - bool pos_has_shift() const { - return has_delta; - } - - void pos_reset_delta() { - has_delta = false; - - for (uint32_t i = 0; i < delta.size(); ++i) { - delta[i] = 0; - } + has_shift = true; } private: uint32_t used = 0; // used cells (i.e. at least one seq_id) - bool has_delta = false; + bool has_shift = false; std::vector pos; - std::vector delta; + + // this array accumulates any applied shifts to the pos array since the last reset_shift() call + // this is used to queue multiple updates to the pos array, which in the end can be applied in one go: + // + // cells.pos_add(x, shift_x); + // cells.pos_div(y, shift_y); + // ... + // for (int i = 0; i < n; ++i) { + // auto shift_i = cells.get_shift(i); + // ... + // } + // cells.reset_shift(); + // + std::vector shift; // TODO: assert n_seq_max <= 64 std::vector> seq; diff --git a/src/llama-memory.h b/src/llama-memory.h index c2571edc715e1..a2d250434affa 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -22,7 +22,7 @@ class llama_memory_i { virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; virtual void seq_keep(llama_seq_id seq_id) = 0; - virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0; + virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0; virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; From 7b3f12a8fbaa04a6376eee609345bdf4c1c0947e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 23 May 2025 11:58:58 +0300 Subject: [PATCH 3/7] llama : add llama_max_parallel_sequences() ggml-ci --- include/llama.h | 1 + src/llama-context.cpp | 7 ++++++- src/llama-cparams.cpp | 4 ++++ src/llama-cparams.h | 2 ++ src/llama-kv-cells.h | 5 +++-- 5 files changed, 16 insertions(+), 3 deletions(-) diff --git a/include/llama.h b/include/llama.h index 52cd7a5a037ef..eafab7323d9bf 100644 --- a/include/llama.h +++ b/include/llama.h @@ -471,6 +471,7 @@ extern "C" { LLAMA_API int64_t llama_time_us(void); LLAMA_API size_t llama_max_devices(void); + LLAMA_API size_t llama_max_parallel_sequences(void); LLAMA_API bool llama_supports_mmap (void); LLAMA_API bool llama_supports_mlock (void); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 85b4324b699e6..ca75b2cbec315 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -25,7 +25,12 @@ llama_context::llama_context( const auto & hparams = model.hparams; - cparams.n_seq_max = std::max(1u, params.n_seq_max); + cparams.n_seq_max = std::max(1u, params.n_seq_max); + if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) { + LLAMA_LOG_WARN("%s: n_seq_max (%d) is larger than the maximum supported (%d) - clamping\n", __func__, cparams.n_seq_max, LLAMA_MAX_PARALLEL_SEQUENCES); + cparams.n_seq_max = LLAMA_MAX_PARALLEL_SEQUENCES; + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; diff --git a/src/llama-cparams.cpp b/src/llama-cparams.cpp index 28369be365252..f7b36590fe3e3 100644 --- a/src/llama-cparams.cpp +++ b/src/llama-cparams.cpp @@ -1 +1,5 @@ #include "llama-cparams.h" + +size_t llama_max_parallel_sequences(void) { + return LLAMA_MAX_PARALLEL_SEQUENCES; +} diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 246fa5777deea..2871031ef0961 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -4,6 +4,8 @@ #include +#define LLAMA_MAX_PARALLEL_SEQUENCES 64 + struct llama_cparams { uint32_t n_ctx; // context size used during inference uint32_t n_batch; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index eaf5ec9c0a86c..e7f70cc56ef47 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-cparams.h" #include #include @@ -119,7 +120,7 @@ class llama_kv_cells_unified { seq[i].reset(seq_id); if (seq[i].none()) { - pos[i]= -1; + pos[i] = -1; used--; @@ -267,6 +268,6 @@ class llama_kv_cells_unified { std::vector shift; // TODO: assert n_seq_max <= 64 - std::vector> seq; + std::vector> seq; }; From 6221dd292e1219133be35cba69b0ec0f88a73601 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 23 May 2025 12:22:50 +0300 Subject: [PATCH 4/7] kv-cells : update comments [no ci] --- src/llama-kv-cells.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index e7f70cc56ef47..3f1f37835897d 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -194,12 +194,15 @@ class llama_kv_cells_unified { return shift[i]; } + // check if a cell is not empty and its position is within [p0, p1) bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { assert(i < pos.size()); return pos[i] >= p0 && pos[i] < p1; } + // set the position of an empty cell + // does not modify "has_shift" // note: call only if the cell is empty void pos_set(uint32_t i, llama_pos p) { assert(i < pos.size()); @@ -210,6 +213,7 @@ class llama_kv_cells_unified { } // pos[i] = pos[i] + d + // sets "has_shift" to true // note: call only if the cell is not empty bool pos_add(uint32_t i, llama_pos d) { assert(i < pos.size()); @@ -233,6 +237,7 @@ class llama_kv_cells_unified { } // pos[i] = pos[i] / d + // sets "has_shift" to true // note: call only if the cell is not empty void pos_div(uint32_t i, int d) { assert(i < pos.size()); @@ -267,7 +272,6 @@ class llama_kv_cells_unified { // std::vector shift; - // TODO: assert n_seq_max <= 64 std::vector> seq; }; From 43b40d3f32ddd1a8ebe3e433abf5103db2d3d242 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 23 May 2025 16:54:36 +0300 Subject: [PATCH 5/7] context : fail upon construction if sequences exceed max value ggml-ci --- src/llama-context.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ca75b2cbec315..98ecb7c8249ce 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -27,8 +27,7 @@ llama_context::llama_context( cparams.n_seq_max = std::max(1u, params.n_seq_max); if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) { - LLAMA_LOG_WARN("%s: n_seq_max (%d) is larger than the maximum supported (%d) - clamping\n", __func__, cparams.n_seq_max, LLAMA_MAX_PARALLEL_SEQUENCES); - cparams.n_seq_max = LLAMA_MAX_PARALLEL_SEQUENCES; + throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES)); } cparams.n_threads = params.n_threads; From f71e737a839e09f50d997f908c1a6f55dcda89c1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 23 May 2025 17:44:18 +0300 Subject: [PATCH 6/7] kv-cells : get_pos() -> pos_get() + comments ggml-ci --- src/llama-kv-cache.cpp | 16 ++++++++-------- src/llama-kv-cells.h | 19 +++++++++++-------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index d1d9fe3f88309..ae2d2684f8cba 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -290,7 +290,7 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { for (uint32_t i = 0; i < cells.size(); ++i) { if (cells.seq_has(i, seq_id)) { - result = std::min(result, cells.get_pos(i)); + result = std::min(result, cells.pos_get(i)); } } @@ -306,7 +306,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { for (uint32_t i = 0; i < cells.size(); ++i) { if (cells.seq_has(i, seq_id)) { - result = std::max(result, cells.get_pos(i)); + result = std::max(result, cells.pos_get(i)); } } @@ -611,7 +611,7 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam continue; } - const llama_pos p0 = cells.get_pos(i); + const llama_pos p0 = cells.pos_get(i); if (p0 <= pmin && !is_masked_swa(p0, pmin)) { n_attended++; @@ -664,7 +664,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub if (cells.is_empty(i)) { masked = true; } else { - const llama_pos p0 = cells.get_pos(i); + const llama_pos p0 = cells.pos_get(i); // mask the token if not the same sequence masked = masked || (!cells.seq_has(i, seq_id)); @@ -724,7 +724,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_kv; ++i) { // the position when the cells is empty is irrelevant - it will be masked out later in the attention - const llama_pos p0 = cells.is_empty(i) ? -1 : cells.get_pos(i); + const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); } @@ -1250,7 +1250,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std:: } } - const llama_pos pos = cells.get_pos(i); + const llama_pos pos = cells.pos_get(i); const uint32_t n_seq_id = seq_ids.size(); io.write(&pos, sizeof(pos)); @@ -1394,8 +1394,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells GGML_ASSERT(head + cell_count <= cells.size()); - GGML_ASSERT(cells.get_pos(head) == batch.pos[0]); - GGML_ASSERT(cells.get_pos(head + cell_count - 1) == batch.pos[cell_count - 1]); + GGML_ASSERT(cells.pos_get(head) == batch.pos[0]); + GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]); GGML_ASSERT(cells.seq_has(head, dest_seq_id)); GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id)); } else { diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index 3f1f37835897d..83b010ac14535 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -57,7 +57,7 @@ class llama_kv_cells_unified { return has_shift; } - // move cell isrc to idst + // move cell isrc to idst (used during defrag) void mv(uint32_t isrc, uint32_t idst) { assert(isrc < pos.size()); assert(idst < pos.size()); @@ -71,7 +71,7 @@ class llama_kv_cells_unified { seq [isrc].reset(); } - // copy the state of cells [i, i + n) + // copy the state of cells [i, i + n) (used for save/restore the state of the cells) llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { assert(i + n <= pos.size()); @@ -89,7 +89,7 @@ class llama_kv_cells_unified { return res; } - // set the state of cells [i, i + other.pos.size()) + // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) void set(uint32_t i, const llama_kv_cells_unified & other) { assert(i + other.pos.size() <= pos.size()); @@ -179,7 +179,7 @@ class llama_kv_cells_unified { } // note: call only if the cell is not empty - llama_pos get_pos(uint32_t i) const { + llama_pos pos_get(uint32_t i) const { assert(i < pos.size()); assert(pos[i] != -1); @@ -264,11 +264,14 @@ class llama_kv_cells_unified { // cells.pos_add(x, shift_x); // cells.pos_div(y, shift_y); // ... - // for (int i = 0; i < n; ++i) { - // auto shift_i = cells.get_shift(i); - // ... + // + // if (cells.has_shift()) { + // for (int i = 0; i < n; ++i) { + // auto shift_i = cells.get_shift(i); + // ... + // } + // cells.reset_shift(); // } - // cells.reset_shift(); // std::vector shift; From dd394a694e2221be5896209a43dc984131725228 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 24 May 2025 13:36:17 +0300 Subject: [PATCH 7/7] kv-cells : fix tracking of "used" cells ggml-ci --- src/llama-kv-cells.h | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index 83b010ac14535..138545533ed22 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -162,20 +162,13 @@ class llama_kv_cells_unified { return seq[i].test(seq_id); } - // note: call only if the cell is not empty - bool seq_add(uint32_t i, llama_seq_id seq_id) { + // note: call only if the cell is not empty and the seq_id is not in the cell + void seq_add(uint32_t i, llama_seq_id seq_id) { assert(i < pos.size()); assert(pos[i] != -1); + assert(!seq[i].test(seq_id)); - if (seq[i].none()) { - seq[i].set(seq_id); - - used++; - - return true; - } - - return false; + seq[i].set(seq_id); } // note: call only if the cell is not empty @@ -252,7 +245,7 @@ class llama_kv_cells_unified { } private: - uint32_t used = 0; // used cells (i.e. at least one seq_id) + uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id) bool has_shift = false;