diff --git a/include/llama.h b/include/llama.h index 29677d74207a3..a2b87bcb4d750 100644 --- a/include/llama.h +++ b/include/llama.h @@ -554,6 +554,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.) + LLAMA_API bool llama_model_is_hybrid_recurrent(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index abf436adac416..70060debbcb77 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -144,6 +144,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -1744,3 +1745,25 @@ llm_arch llm_arch_from_string(const std::string & name) { const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { return LLM_TENSOR_INFOS.at(tensor); } + +bool llm_arch_is_recurrent(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_MAMBA: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + return true; + default: + return false; + } +} + +bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) { + // TODO: There are currently no hybrid models! Once there are, this will be + // the place to identify them + switch (arch) { + default: + return false; + } +} diff --git a/src/llama-arch.h b/src/llama-arch.h index 41a023da3da6e..35c917a5c365a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -148,6 +148,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_LAYER_INDICES, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, @@ -435,3 +436,6 @@ const char * llm_arch_name(llm_arch arch); llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); + +bool llm_arch_is_recurrent(const llm_arch& arch); +bool llm_arch_is_hybrid_recurrent(const llm_arch& arch); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 1499eb08a5dd9..70a7114f39715 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } -uint32_t llama_hparams::n_embd_k_s() const { +uint32_t llama_hparams::n_embd_k_s(uint32_t il) const { + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // for RWKV models return token_shift_count * n_embd; @@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const { return (ssm_d_conv > 0 ? 
ssm_d_conv - 1 : 0) * ssm_d_inner; } -uint32_t llama_hparams::n_embd_v_s() const { +uint32_t llama_hparams::n_embd_v_s(uint32_t il) const { + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // corresponds to RWKV's wkv_states size return n_embd * wkv_head_size; @@ -86,6 +92,10 @@ uint32_t llama_hparams::n_embd_v_s() const { return ssm_d_state * ssm_d_inner; } +bool llama_hparams::recurrent_layer(uint32_t il) const { + return recurrent_layer_arr[il]; +} + bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { return swa_layers[il]; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 2d72eab180ad0..e10741a104cb7 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -115,6 +115,9 @@ struct llama_hparams { uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + // for hybrid state space models + std::array recurrent_layer_arr; + bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -178,10 +181,13 @@ struct llama_hparams { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size - uint32_t n_embd_k_s() const; + uint32_t n_embd_k_s(uint32_t il = 0) const; // dimension of the recurrent state embeddings - uint32_t n_embd_v_s() const; + uint32_t n_embd_v_s(uint32_t il = 0) const; + + // whether or not the given layer is recurrent (for hybrid models) + bool recurrent_layer(uint32_t il) const; bool is_swa(uint32_t il) const; }; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3e3d26286e1ee..0763feddf66d5 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -129,8 +129,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); const char * dev_name = "CPU"; @@ -352,14 +352,19 @@ llama_memory_decode_state_ptr llama_kv_cache_unified::init( const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, - bool logits_all) { + bool logits_all, + bool split_equal) { GGML_UNUSED(embd_pooled); auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); std::vector ubatches; while (sbatch.n_tokens > 0) { - ubatches.push_back(sbatch.split_simple(n_ubatch)); + if (split_equal) { + ubatches.push_back(sbatch.split_equal(n_ubatch)); + } else { + ubatches.push_back(sbatch.split_simple(n_ubatch)); + } } auto heads = prepare(ubatches); @@ -394,7 +399,7 @@ std::vector llama_kv_cache_unified::prepare(const std::vectortype; @@ -1401,7 +1414,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1425,7 +1438,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1565,7 +1578,7 @@ 
bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Read type of key int32_t k_type_i_ref; @@ -1595,7 +1608,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; @@ -1625,7 +1638,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; @@ -1813,7 +1826,12 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { return kv_swa->seq_pos_max(seq_id); } -llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { +llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all, + bool split_equal) { GGML_UNUSED(embd_pooled); auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); @@ -1821,9 +1839,11 @@ llama_memory_decode_state_ptr llama_kv_cache_unified_iswa::init(const llama_batc std::vector ubatches; while (sbatch.n_tokens > 0) { - auto ubatch = sbatch.split_simple(n_ubatch); - - ubatches.push_back(ubatch); + if (split_equal) { + ubatches.push_back(sbatch.split_equal(n_ubatch)); + } else { + ubatches.push_back(sbatch.split_simple(n_ubatch)); + } } auto heads_base = kv_base->prepare(ubatches); @@ -1861,6 +1881,15 @@ void llama_kv_cache_unified_iswa::set_full() { kv_swa ->set_full(); } +bool llama_kv_cache_unified_iswa::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + GGML_UNUSED(seq_id); + GGML_UNUSED(p0); + GGML_UNUSED(p1); + // Unified attention caches can always do a sequence removal, so since both + // children can, the parent can as well. 
+ return true; +} + bool llama_kv_cache_unified_iswa::get_can_shift() const { return kv_base->get_size() == kv_swa->get_size(); } @@ -1934,12 +1963,13 @@ class llama_kv_cache_recurrent_decode_state_t : public llama_memory_decode_state }; llama_kv_cache_recurrent::llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", @@ -1981,8 +2011,13 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + if (filter && !filter(i)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i); + continue; + } + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i); const char * dev_name = "CPU"; @@ -2006,8 +2041,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); - k_l.push_back(k); - v_l.push_back(v); + k_l[i] = k; + v_l[i] = v; } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -2051,39 +2086,33 @@ void llama_kv_cache_recurrent::clear() { } bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = size; + if (!can_seq_rm(seq_id, p0, p1)) { + // could be fatal + return false; + } + uint32_t new_head = size; if (p0 < 0) { p0 = 0; } - if (p1 < 0) { p1 = std::numeric_limits::max(); } - // models like Mamba or RWKV can't have a state partially erased - if (seq_id >= (int64_t) size) { - // could be fatal - return false; - } if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { const kv_cell & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } + // already validated in can_seq_rm + GGML_ASSERT(!((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos))); // invalidate tails which will be cleared if (p0 <= cell.pos && cell.pos < p1) { tail_id = -1; } } } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } + // already validated in can_seq_rm + GGML_ASSERT(!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max()))); } for (uint32_t i = 0; i < size; ++i) { @@ -2274,8 +2303,15 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } -llama_memory_decode_state_ptr llama_kv_cache_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { +llama_memory_decode_state_ptr llama_kv_cache_recurrent::init( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all, + bool split_equal) { GGML_UNUSED(embd_pooled); + // TODO: Should this just be 
ignored? + assert(split_equal); auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); @@ -2349,6 +2385,34 @@ void llama_kv_cache_recurrent::set_full() { n = size; head = 0; } +bool llama_kv_cache_recurrent::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + const int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + } + // seq_id is negative, then the range should include everything or nothing + } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + return true; +} bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; @@ -2685,7 +2749,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Write key type const int32_t k_type_i = (int32_t)k_l[il]->type; @@ -2705,7 +2769,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -2726,7 +2790,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -2873,7 +2937,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Read type of key int32_t k_type_i_ref; @@ -2901,7 +2965,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; @@ -2929,7 +2993,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + 
hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -2969,3 +3033,279 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     return true;
 }
+
+//
+// llama_kv_cache_hybrid
+//
+
+
+class llama_kv_cache_hybrid_decode_state_t : public llama_memory_decode_state_i {
+public:
+    explicit llama_kv_cache_hybrid_decode_state_t(
+        std::vector<llama_memory_decode_state_ptr> decode_states) :
+        status([](const std::vector<llama_memory_decode_state_ptr> & decode_states) -> llama_memory_status {
+            for (const auto & decode_state : decode_states) {
+                if (!decode_state) {
+                    return LLAMA_MEMORY_STATUS_FAILED_PREPARE;
+                }
+                const auto & status = decode_state->get_status();
+                if (status != LLAMA_MEMORY_STATUS_SUCCESS) {
+                    return status;
+                }
+            }
+            return LLAMA_MEMORY_STATUS_SUCCESS;
+        }(decode_states)),
+        decode_states(std::move(decode_states)) {
+
+        // make sure at least one decode state
+        // NOTE: use this-> to refer to the member; the constructor parameter of
+        // the same name has already been moved from at this point
+        assert(!this->decode_states.empty());
+
+        // make sure all out_ids match across states
+        // TODO: This could be expensive, so maybe don't do it?
+        const auto & out_ids = this->decode_states[0]->out_ids();
+        for (size_t i = 1; i < this->decode_states.size(); ++i) {
+            const auto & out_ids_i = this->decode_states[i]->out_ids();
+            assert(out_ids.size() == out_ids_i.size());
+            for (size_t j = 0; j < out_ids.size(); ++j) {
+                assert(out_ids[j] == out_ids_i[j]);
+            }
+        }
+    }
+
+    ~llama_kv_cache_hybrid_decode_state_t() = default;
+
+    llama_ubatch * next() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        // hit next on each child
+        std::vector<llama_ubatch *> next_ubatches;
+        for (const auto & decode_state : decode_states) {
+            next_ubatches.push_back(decode_state->next());
+        }
+
+        // make sure they all match
+        // TODO: unnecessary safety?
+        llama_ubatch * res = next_ubatches[0];
+        assert(res);
+        for (size_t i = 1; i < next_ubatches.size(); ++i) {
+            llama_ubatch * ubatch_i = next_ubatches[i];
+            assert(ubatch_i);
+            assert(ubatch_i->n_tokens == res->n_tokens);
+            assert(ubatch_i->n_seq_tokens == res->n_seq_tokens);
+            assert(ubatch_i->n_seqs == res->n_seqs);
+            for (size_t j = 0; j < res->n_tokens; ++j) {
+                assert(ubatch_i->token[j] == res->token[j]);
+                assert(ubatch_i->pos[j] == res->pos[j]);
+                assert(ubatch_i->output[j] == res->output[j]);
+            }
+            for (size_t j = 0; j < res->n_seqs; ++j) {
+                assert(ubatch_i->n_seq_id[j] == res->n_seq_id[j]);
+            }
+        }
+
+        // return the first ubatch since they all match
+        return res;
+    }
+
+    std::vector<int64_t> & out_ids() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        return decode_states[0]->out_ids();
+    }
+
+    llama_memory_status get_status() const override {
+        return status;
+    }
+
+private:
+    const llama_memory_status status;
+    std::vector<llama_memory_decode_state_ptr> decode_states;
+};
+
+llama_kv_cache_hybrid::llama_kv_cache_hybrid(std::vector<child_cache> children_) :
+    children(
+        [](std::vector<child_cache> & caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
+            // Sort the caches by the lowest layer ID so the order is repeatable
+            for (auto & cache : caches) {
+                GGML_ASSERT(cache.layer_ids.size() > 0);
+                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
+            }
+            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
+                return a.layer_ids[0] < b.layer_ids[0];
+            });
+            std::set<std::unique_ptr<llama_kv_cache>> unique_caches;
+            for (auto & cache : caches) {
+                unique_caches.emplace(cache.child.release());
+            }
+            return unique_caches;
+        }(children_)
+    ),
+    has_recurrent(
+        [](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
+            for (const auto & cache : caches) {
+                if (dynamic_cast<const llama_kv_cache_recurrent *>(cache.get())) {
+                    return true;
+                }
+            }
+            return false;
+        }(children)
+    )
+{
+    // Ensure at least one child
+    GGML_ASSERT(children.size() > 0);
+
+    // Ensure layers are not overlapping and are contiguous
+    std::set<size_t> seen_layers;
+    size_t max_layer = 0;
+    for (const auto & cache : children_) {
+        for (const auto & layer_id : cache.layer_ids) {
+            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
+            seen_layers.insert(layer_id);
+            if (layer_id > max_layer) {
+                max_layer = layer_id;
+            }
+        }
+    }
+    LLAMA_LOG_DEBUG("max_layer=%zu, seen_layers.size()=%zu\n", max_layer, seen_layers.size());
+    GGML_ASSERT(max_layer + 1 == seen_layers.size());
+}
+
+void llama_kv_cache_hybrid::clear() {
+    for (const auto & cache : children) {
+        cache->clear();
+    }
+}
+
+bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // First check if we can do this removal. This checks all children so that
+    // no mutation happens before we know if it's possible
+    if (!can_seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+
+    // Do the removal from each child which should never fail
+    for (const auto & cache : children) {
+        const bool ok = cache->seq_rm(seq_id, p0, p1);
+        GGML_ASSERT(ok);
+    }
+    return true;
+}
+
+void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    for (const auto & cache : children) {
+        cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
+    for (const auto & cache : children) {
+        cache->seq_keep(seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    for (const auto & cache : children) {
+        cache->seq_add(seq_id, p0, p1, delta);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    for (const auto & cache : children) {
+        cache->seq_div(seq_id, p0, p1, d);
+    }
+}
+
+llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const {
+    llama_pos min_pos = -1;
+    for (const auto & cache : children) {
+        const auto child_min_pos = cache->seq_pos_min(seq_id);
+        min_pos = min_pos == -1 ? child_min_pos : std::min(min_pos, child_min_pos);
+    }
+    return min_pos;
+}
+
+llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos max_pos = 0;
+    for (const auto & cache : children) {
+        max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
+    }
+    return max_pos;
+}
+
+llama_memory_decode_state_ptr llama_kv_cache_hybrid::init(
+        const llama_batch & batch,
+        uint32_t n_ubatch,
+        bool embd_pooled,
+        bool logits_all,
+        bool split_equal) {
+
+    // recurrent children require equal splits
+    // TODO: just ignore this if set incorrectly?
+ assert(!has_recurrent || split_equal); + + // init all children and capture their decode states + std::vector decode_states; + for (const auto & child : children) { + decode_states.emplace_back( + child->init(batch, n_ubatch, embd_pooled, logits_all, split_equal)); + } + + // return the hybrid decode state + return std::make_unique(std::move(decode_states)); +} + +bool llama_kv_cache_hybrid::update(llama_context & ctx) { + bool updated = false; + for (const auto & cache : children) { + updated = cache->update(ctx) || updated; + } + return updated; +} + +void llama_kv_cache_hybrid::defrag_sched(float thold) { + for (const auto & cache : children) { + cache->defrag_sched(thold); + } +} + +void llama_kv_cache_hybrid::set_full() { + for (const auto & cache : children) { + cache->set_full(); + } +} + +bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + for (const auto & cache : children) { + if (!cache->can_seq_rm(seq_id, p0, p1)) { + return false; + } + } + return true; +} + +bool llama_kv_cache_hybrid::get_can_shift() const { + // TODO: Is this correct? + // If any children can shift, return true + for (const auto & cache : children) { + if (cache->get_can_shift()) { + return true; + } + } + return false; +} + +void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + // Write each cache state in order. Note that order is guaranteed at + // initialization by using an ordered set sorted by lowest layer ID + for (const auto & cache : children) { + cache->state_write(io, seq_id); + } +} + +void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + // Read each cache state in order. Note that order is guaranteed at + // initialization by using an ordered set sorted by lowest layer ID + for (const auto & cache : children) { + cache->state_read(io, seq_id); + } +} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index f1ba7cba390e2..d47f25402cee4 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -20,6 +20,12 @@ struct llama_model; struct llama_context; struct llama_kv_cache : public llama_memory_i { + + // some child types need to perform different caching for each layer, so + // this callback can be used to determine which layers a given cache should + // be used for + using layer_filter_cb = std::function; + virtual ~llama_kv_cache() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache @@ -28,7 +34,8 @@ struct llama_kv_cache : public llama_memory_i { const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, - bool logits_all) = 0; + bool logits_all, + bool split_equal = false) = 0; // process any pending defrag/shift/etc. 
operations // optionally call once before processing a new batch @@ -42,6 +49,11 @@ struct llama_kv_cache : public llama_memory_i { // TODO: remove virtual void set_full() = 0; + // sometimes it is useful to check whether a cache can remove a sequence + // before attempting to mutate the cache (eg a hybrid cache with multiple + // children to keep in sync) + virtual bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const = 0; + // getters virtual bool get_can_shift() const = 0; @@ -63,9 +75,6 @@ class llama_kv_cache_unified : public llama_kv_cache { public: static uint32_t get_padding(const llama_cparams & cparams); - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - llama_kv_cache_unified( const llama_model & model, layer_filter_cb && filter, @@ -104,7 +113,8 @@ class llama_kv_cache_unified : public llama_kv_cache { const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, - bool logits_all) override; + bool logits_all, + bool split_equal = false) override; bool update(llama_context & lctx) override; @@ -112,6 +122,8 @@ class llama_kv_cache_unified : public llama_kv_cache { void set_full() override; + bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; + bool get_can_shift() const override; // state write/load @@ -279,7 +291,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, - bool logits_all) override; + bool logits_all, + bool split_equal = false) override; bool update(llama_context & lctx) override; @@ -287,6 +300,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { void set_full() override; + bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; + bool get_can_shift() const override; // state write/load @@ -315,12 +330,13 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache { class llama_kv_cache_recurrent : public llama_kv_cache { public: llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max); + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max); ~llama_kv_cache_recurrent() = default; @@ -347,7 +363,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, - bool logits_all) override; + bool logits_all, + bool split_equal = true) override; bool update(llama_context & lctx) override; @@ -355,6 +372,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { void set_full() override; + bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; + bool prepare(const std::vector & ubatches); // find a contiguous slot of kv cells and emplace the ubatch there @@ -424,3 +443,86 @@ class llama_kv_cache_recurrent : public llama_kv_cache { bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; + +// +// llama_kv_cache_hybrid +// + +// utilizes multiple different cache types with each layer assigned to exactly +// one cache. 
This is typically used for hybrid attention / recurrent caching
+
+class llama_kv_cache_hybrid : public llama_kv_cache {
+public:
+
+    struct child_cache {
+        std::unique_ptr<llama_kv_cache> child;
+        std::vector<size_t> layer_ids;
+
+        child_cache(std::unique_ptr<llama_kv_cache> child_, std::vector<size_t> layer_ids_)
+            : child(std::move(child_)), layer_ids(std::move(layer_ids_)) {}
+    };
+
+    explicit llama_kv_cache_hybrid(std::vector<child_cache> children);
+    virtual ~llama_kv_cache_hybrid() = default;
+
+    // getter for a specific child cache type
+    // NOTE: This will abort if there are multiple children of the given type
+    template <typename child_t>
+    const child_t * get_child_cache() const {
+        const child_t * child = nullptr;
+        for (const auto & child_cache : children) {
+            const child_t * child_cast = dynamic_cast<const child_t *>(child_cache.get());
+            if (child_cast) {
+                GGML_ASSERT(!child);
+                child = child_cast;
+            }
+        }
+        return child;
+    }
+
+    //
+    // llama_memory_i
+    //
+
+    void clear() override;
+
+    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    //
+    // llama_kv_cache
+    //
+
+    llama_memory_decode_state_ptr init(
+        const llama_batch & batch,
+        uint32_t n_ubatch,
+        bool embd_pooled,
+        bool logits_all,
+        bool split_equal = true) override;
+
+    bool update(llama_context & ctx) override;
+
+    void defrag_sched(float thold) override;
+
+    void set_full() override;
+
+    bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
+
+    bool get_can_shift() const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
+
+private:
+
+    const std::set<std::unique_ptr<llama_kv_cache>> children; // Ordered for state IO
+    const bool has_recurrent;
+};
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index e99f5309f9904..110c4863880b3 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -466,6 +466,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
+    std::fill(
+        hparams.recurrent_layer_arr.begin(),
+        hparams.recurrent_layer_arr.end(),
+        llm_arch_is_recurrent(ml.get_arch()));
 
     std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
 
@@ -13192,6 +13196,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
        case LLM_ARCH_BERT:
        case LLM_ARCH_JINA_BERT_V2:
        case LLM_ARCH_NOMIC_BERT:
@@ -13200,57 +13206,108 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
            {
                res = nullptr;
            } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        //
Models that need standard caching should rely on recurrent/hybrid + // checks default: { - const auto padding = llama_kv_cache_unified::get_padding(cparams); - - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - GGML_ASSERT(hparams.is_swa_any()); + if (llm_arch_is_hybrid_recurrent(arch)) { + // make vectors of recurrent and non-recurrent layer indices + std::vector recurrent_layers; + std::vector unified_layers; + for (auto il = 0u; il < hparams.n_layer; ++il) { + if (hparams.recurrent_layer(il)) { + recurrent_layers.push_back(il); + } else { + unified_layers.push_back(il); + } + } - res = new llama_kv_cache_unified_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.n_ctx, - cparams.n_seq_max, - cparams.n_batch, - padding); - } else { - GGML_ASSERT(!hparams.is_swa_any()); + const auto padding = llama_kv_cache_unified::get_padding(cparams); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + // initialize the children + std::vector children; + children.emplace_back( + std::unique_ptr( + new llama_kv_cache_recurrent( + *this, + [&](int32_t il) { + return hparams.recurrent_layer(il); + }, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max) + ), + std::move(recurrent_layers) + ); + children.emplace_back( + std::unique_ptr( + new llama_kv_cache_unified( + *this, + [&](int32_t il) { + return ! hparams.recurrent_layer(il); + }, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + cparams.n_seq_max, + padding, + hparams.n_swa, + hparams.swa_type) + ), + std::move(unified_layers) + ); - res = new llama_kv_cache_unified( + // initialize the hybrid cache with both children + res = new llama_kv_cache_hybrid(std::move(children)); + } else if (llm_arch_is_recurrent(arch)) { + res = new llama_kv_cache_recurrent( *this, nullptr, - params.type_k, - params.type_v, - !cparams.flash_attn, + GGML_TYPE_F32, + GGML_TYPE_F32, cparams.offload_kqv, - cparams.n_ctx, - cparams.n_seq_max, - padding, - hparams.n_swa, - hparams.swa_type); + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max + ); + } else { + const auto padding = llama_kv_cache_unified::get_padding(cparams); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + if (hparams.n_swa > 0) { + res = new llama_kv_cache_unified_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + params.swa_full, + cparams.n_seq_max, + cparams.n_batch, + padding); + } else { + res = new llama_kv_cache_unified( + *this, + nullptr, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + cparams.n_seq_max, + padding, + hparams.n_swa, + hparams.swa_type); + } } } } @@ -13799,14 +13856,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) { } bool llama_model_is_recurrent(const llama_model * model) { - switch (model->arch) { - case LLM_ARCH_MAMBA: return true; - case LLM_ARCH_RWKV6: return true; - case LLM_ARCH_RWKV6QWEN2: return true; - case LLM_ARCH_RWKV7: return true; - case LLM_ARCH_ARWKV7: return true; - default: return false; - } + return llm_arch_is_recurrent(model->arch); +} 
+ +bool llama_model_is_hybrid_recurrent(const llama_model * model) { + return llm_arch_is_hybrid_recurrent(model->arch); } const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 62a9f5842bca8..ff3d97d7a27eb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -109,6 +109,7 @@ if (NOT WIN32) llama_build_and_test(test-grammar-integration.cpp) llama_build_and_test(test-llama-grammar.cpp) llama_build_and_test(test-chat.cpp) + llama_build_and_test(test-memory.cpp) # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8 if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") llama_build_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/tests/test-memory.cpp b/tests/test-memory.cpp new file mode 100644 index 0000000000000..33dd22fc7eae8 --- /dev/null +++ b/tests/test-memory.cpp @@ -0,0 +1,300 @@ +/*------------------------------------------------------------------------------ + * Unit tests for llama-memory.h and derived memory implementations. It contains + * a number of tests which can be run all together or separately. + * + * USAGE: ./bin/test-memory + * + * When adding a new test, do the following: + * + * 1. Add the new test__description function under the + * appropriate memory type section + * + * 2. Add `RUN_TEST(test__description);` to main + *----------------------------------------------------------------------------*/ + +#include "../src/llama-arch.h" +#include "../src/llama-batch.h" +#include "../src/llama-hparams.h" +#include "../src/llama-impl.h" +#include "../src/llama-kv-cache.h" +#include "../src/llama-model.h" + +#include "common.h" +#include "ggml.h" +#include "llama.h" + +#include +#include +#include + +/*- Helpers ------------------------------------------------------------------*/ + +static std::shared_ptr _make_model( + llm_arch arch = LLM_ARCH_LLAMA, + uint32_t n_layer = 4, + uint32_t n_embd_head_k = 4, + uint32_t n_embd_head_v = 4, + uint32_t n_head = 8, + uint32_t n_head_kv = 2) { + + llama_model_params params; + params.tensor_buft_overrides = nullptr; + std::shared_ptr model(new llama_model(params)); + model->hparams = llama_hparams(); + model->arch = arch; + + model->hparams.n_layer = n_layer; + model->hparams.n_embd_head_k = n_embd_head_k; + model->hparams.n_embd_head_v = n_embd_head_v; + + // If set to 0, assume the test will fill out the array elementwise (hybrid) + if (n_head > 0) { + auto& n_head_arr = model->hparams.n_head_arr; + std::fill(n_head_arr.begin(), n_head_arr.end(), n_head); + } + if (n_head_kv > 0) { + auto& n_head_kv_arr = model->hparams.n_head_kv_arr; + std::fill(n_head_kv_arr.begin(), n_head_kv_arr.end(), n_head_kv); + } + + return model; +} + +static llama_batch _make_batch( + std::vector> token_seqs, + std::vector> seq_ids) { + GGML_ASSERT(token_seqs.size() == seq_ids.size()); + + size_t total_tokens = 0; + for (const auto & token_seq : token_seqs) { + total_tokens += token_seq.size(); + } + size_t max_seq_ids = 0; + for (const auto & seq_ids_i : seq_ids) { + max_seq_ids = std::max(max_seq_ids, seq_ids_i.size()); + } + llama_batch batch = llama_batch_init(total_tokens, 0, max_seq_ids); + + for (size_t i = 0; i < token_seqs.size(); ++i) { + const auto& token_seq = token_seqs[i]; + const auto& seq_ids_i = seq_ids[i]; + for (int pos = 0; pos < (int)token_seq.size(); ++pos) { + common_batch_add(batch, token_seq[pos], pos, seq_ids_i, false); + } + } + return batch; +} + +static 
bool is_source_tensor(ggml_tensor * child, ggml_tensor * parent) { + if (!child || !parent) return false; + for (size_t i = 0; i < GGML_MAX_SRC; ++i) { + if (child->src[i] == parent) { + return true; + } else if (child->src[i] != nullptr && is_source_tensor(child->src[i], parent)) { + return true; + } + } + return false; +} + +struct log_scope { + const char * name; + explicit log_scope(const char * name) : name(name) { + LLAMA_LOG_INFO("--------\n"); + LLAMA_LOG_INFO("START: %s\n", name); + } + ~log_scope() { + LLAMA_LOG_INFO("END: %s\n", name); + LLAMA_LOG_INFO("--------\n"); + } +}; + +#define RUN_TEST(test_name) \ + do { \ + bool run_test = argc < 2; \ + std::vector args(argv + 1, argv + argc); \ + if (std::find(args.begin(), args.end(), #test_name) != args.end()) \ + run_test = true; \ + if (run_test) { \ + log_scope __log_scope(#test_name); \ + test_name(); \ + } \ + } while (0) + +/*- Unified Cache ------------------------------------------------------------*/ + +/* Test that the unified cache can be constructed and destructed safely */ +static void test_llama_kv_cache_unified_constructor() { + auto model = _make_model(); + llama_kv_cache_unified cache( + /* model */ *model, + /* filter */ nullptr, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* v_trans */ false, + /* offload */ false, + /* kv_size */ 10, + /* n_seq_max */ 1, + /* padding */ 10, + /* n_swa */ 0, + /* swa_type */ LLAMA_SWA_TYPE_NONE + ); +} + +/* Test that the unified cache can operate with a single seq */ +static void test_llama_kv_cache_unified_single_seq() { + auto model = _make_model(); + llama_kv_cache_unified cache( + /* model */ *model, + /* filter */ nullptr, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* v_trans */ false, + /* offload */ false, + /* kv_size */ 10, + /* n_seq_max */ 1, + /* padding */ 10, + /* n_swa */ 0, + /* swa_type */ LLAMA_SWA_TYPE_NONE + ); + + // // Create the micro batch with a single 3-token sequence + // llama_batch batch1 = _make_batch({{101, 1, 102}}, {{42}}); + // llama_sbatch sbatch1 = cache.sbatch_init(batch1, false); + // llama_ubatch ubatch1 = cache.ubatch_next(sbatch1, 4, false); + + // // Find a slot for a new sequence + // GGML_ASSERT(cache.find_slot(ubatch1)); + + // // Cache the k/v for a single layer in this slot + // ggml_context * ctx = ggml_init({10240, NULL, false}); + // ggml_tensor * k1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0)); + // ggml_tensor * v1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0)); + // ggml_tensor * k1_view = cache.cpy_k(ctx, k1, 0); + // ggml_tensor * v1_view = cache.cpy_v(ctx, v1, 0); + // GGML_ASSERT(is_source_tensor(k1_view, k1)); + // GGML_ASSERT(is_source_tensor(v1_view, v1)); + + // // Create a second batch with different tokens and find a slot for it + // llama_batch batch2 = _make_batch({{1, 2, 3, 4}}, {{5}}); + // llama_sbatch sbatch2 = cache.sbatch_init(batch2, false); + // llama_ubatch ubatch2 = cache.ubatch_next(sbatch2, 4, false); + // GGML_ASSERT(cache.find_slot(ubatch2)); + + // // Add some different tensors + // ggml_tensor * k2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_k_gqa(0)); + // ggml_tensor * v2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, model->hparams.n_embd_v_gqa(0)); + // ggml_tensor * k2_view = cache.cpy_k(ctx, k2, 0); + // ggml_tensor * v2_view = cache.cpy_v(ctx, v2, 0); + // GGML_ASSERT(is_source_tensor(k2_view, k2)); + // GGML_ASSERT(is_source_tensor(v2_view, v2)); + + // // Make sure first batch's k/v 
aren't cache hit
+    // GGML_ASSERT(!is_source_tensor(k2_view, k1));
+    // GGML_ASSERT(!is_source_tensor(v2_view, v1));
+
+    // // Re-find the slot for the first batch and make sure they cache hit
+    // GGML_ASSERT(cache.find_slot(ubatch1));
+
+    // // Clean up
+    // llama_batch_free(batch1);
+    // llama_batch_free(batch2);
+    // ggml_free(ctx);
+}
+
+/*- Recurrent Cache ----------------------------------------------------------*/
+
+/* Test that the recurrent cache can be constructed and destructed safely */
+static void test_llama_kv_cache_recurrent_constructor() {
+    auto model = _make_model(LLM_ARCH_MAMBA);
+    llama_kv_cache_recurrent cache(
+        /* model */ *model,
+        /* filter */ nullptr,
+        /* type_k */ GGML_TYPE_F32,
+        /* type_v */ GGML_TYPE_F16,
+        /* offload */ false,
+        /* kv_size */ 10,
+        /* n_seq_max */ 1
+    );
+}
+
+/*- Hybrid Cache -------------------------------------------------------------*/
+
+/* Test that the hybrid cache can be constructed and destructed safely */
+static void test_llama_kv_cache_hybrid_constructor() {
+    auto model = _make_model(
+        /* arch =*/ LLM_ARCH_LLAMA,
+        /* n_layer =*/ 4,
+        /* n_embd_head_k =*/ 4,
+        /* n_embd_head_v =*/ 4,
+        /* n_head =*/ 0,
+        /* n_head_kv =*/ 0
+    );
+    auto recurrent_filter = [](int32_t il) {
+        return il == 0 || il == 2;
+    };
+    auto unified_filter = [&recurrent_filter](int32_t il) {
+        return !recurrent_filter(il);
+    };
+    auto& n_head_arr = model->hparams.n_head_arr;
+    n_head_arr[0] = 16;
+    n_head_arr[1] = 32;
+    n_head_arr[2] = 16;
+    n_head_arr[3] = 32;
+    auto& n_head_kv_arr = model->hparams.n_head_kv_arr;
+    n_head_kv_arr[0] = 16;
+    n_head_kv_arr[1] = 8;
+    n_head_kv_arr[2] = 16;
+    n_head_kv_arr[3] = 8;
+
+    std::unique_ptr<llama_kv_cache_unified> u_cache(
+        new llama_kv_cache_unified(
+            /* model */ *model,
+            /* filter */ unified_filter,
+            /* type_k */ GGML_TYPE_F32,
+            /* type_v */ GGML_TYPE_F16,
+            /* v_trans */ false,
+            /* offload */ false,
+            /* kv_size */ 10,
+            /* n_seq_max */ 1,
+            /* padding */ 10,
+            /* n_swa */ 0,
+            /* swa_type */ LLAMA_SWA_TYPE_NONE
+        )
+    );
+    auto * u_cache_ptr = u_cache.get();
+    std::unique_ptr<llama_kv_cache_recurrent> r_cache (
+        new llama_kv_cache_recurrent(
+            /* model */ *model,
+            /* filter */ recurrent_filter,
+            /* type_k */ GGML_TYPE_F32,
+            /* type_v */ GGML_TYPE_F16,
+            /* offload */ false,
+            /* kv_size */ 10,
+            /* n_seq_max */ 1
+        )
+    );
+    auto * r_cache_ptr = r_cache.get();
+
+    std::vector<llama_kv_cache_hybrid::child_cache> children;
+    children.emplace_back(std::move(u_cache), std::vector<size_t>{1, 3});
+    children.emplace_back(std::move(r_cache), std::vector<size_t>{0, 2});
+
+    llama_kv_cache_hybrid cache(std::move(children));
+
+    GGML_ASSERT(cache.get_child_cache<llama_kv_cache_unified>() == u_cache_ptr);
+    GGML_ASSERT(cache.get_child_cache<llama_kv_cache_recurrent>() == r_cache_ptr);
+}
+
+/*- Main ---------------------------------------------------------------------*/
+
+int main(int argc, char* argv[]) {
+    // Unified Cache Tests
+    RUN_TEST(test_llama_kv_cache_unified_constructor);
+    RUN_TEST(test_llama_kv_cache_unified_single_seq);
+    // Recurrent Cache Tests
+    RUN_TEST(test_llama_kv_cache_recurrent_constructor);
+    // Hybrid Cache Tests
+    RUN_TEST(test_llama_kv_cache_hybrid_constructor);
+    return 0;
+}
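Editor's note: the constructor test above only checks that a hybrid cache can be built. A natural follow-up would exercise the new can_seq_rm() plumbing, since keeping the children in sync is the main reason the hybrid cache needs it. The sketch below is not part of the patch: the test name is hypothetical, it assumes the constructor signatures shown in test-memory.cpp above, it relies on llama_kv_cache_unified::can_seq_rm() always returning true (only the iswa implementation is visible in this diff), and it would need a matching RUN_TEST(test_llama_kv_cache_hybrid_can_seq_rm); entry in main. Like the constructor tests, it only exercises sequence bookkeeping, not actual K/V data.

/* Sketch of a hypothetical follow-up test: can_seq_rm() must be the AND of the
 * children, and seq_rm() must refuse up front rather than mutate only some of
 * them. */
static void test_llama_kv_cache_hybrid_can_seq_rm() {
    auto model = _make_model(LLM_ARCH_LLAMA, /* n_layer */ 2);

    std::vector<llama_kv_cache_hybrid::child_cache> children;
    children.emplace_back(
        std::unique_ptr<llama_kv_cache>(new llama_kv_cache_unified(
            /* model */ *model,
            /* filter */ [](int32_t il) { return il == 0; },
            /* type_k */ GGML_TYPE_F32,
            /* type_v */ GGML_TYPE_F16,
            /* v_trans */ false,
            /* offload */ false,
            /* kv_size */ 10,
            /* n_seq_max */ 1,
            /* padding */ 10,
            /* n_swa */ 0,
            /* swa_type */ LLAMA_SWA_TYPE_NONE)),
        std::vector<size_t>{0});
    children.emplace_back(
        std::unique_ptr<llama_kv_cache>(new llama_kv_cache_recurrent(
            /* model */ *model,
            /* filter */ [](int32_t il) { return il == 1; },
            /* type_k */ GGML_TYPE_F32,
            /* type_v */ GGML_TYPE_F32,
            /* offload */ false,
            /* kv_size */ 10,
            /* n_seq_max */ 1)),
        std::vector<size_t>{1});

    llama_kv_cache_hybrid cache(std::move(children));

    // removing a whole sequence is representable by every child, so the hybrid
    // cache must allow it and the actual removal must succeed
    GGML_ASSERT(cache.can_seq_rm(0, 0, -1));
    GGML_ASSERT(cache.seq_rm(0, 0, -1));

    // a negative seq_id with a partial range is rejected by the recurrent child
    // (see llama_kv_cache_recurrent::can_seq_rm), so the hybrid cache must
    // reject it before touching any child
    GGML_ASSERT(!cache.can_seq_rm(-1, 2, 5));
    GGML_ASSERT(!cache.seq_rm(-1, 2, 5));
}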