
Commit 6468631

llama : use n_swa + n_ubatch cells for SWA cache
ggml-ci
1 parent 1adcd4b commit 6468631

File tree

5 files changed: +12 additions, -11 deletions

  include/llama.h
  src/llama-kv-cache.cpp
  src/llama-kv-cache.h
  src/llama-model.cpp
  tools/server/server.cpp

include/llama.h

Lines changed: 1 addition & 0 deletions

@@ -502,6 +502,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_layer   (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_swa     (const struct llama_model * model);
 
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
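A minimal usage sketch of the new getter, not part of this commit: it assumes the current model-loading API (llama_model_load_from_file / llama_model_free) and simply prints the model's sliding-window size.

// Sketch (not from the commit): query the sliding-window size of a loaded model.
#include "llama.h"
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (model == nullptr) {
        return 1;
    }

    // n_swa == 0 indicates the model does not use sliding-window attention
    printf("n_swa = %d\n", llama_model_n_swa(model));

    llama_model_free(model);
    llama_backend_free();
    return 0;
}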

src/llama-kv-cache.cpp

Lines changed: 2 additions & 2 deletions

@@ -1738,14 +1738,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool     swa_full,
         uint32_t kv_size,
         uint32_t n_seq_max,
-        uint32_t n_batch,
+        uint32_t n_ubatch,
         uint32_t n_pad) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
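To make the sizing rule concrete, here is a standalone sketch of the computation above with illustrative numbers (the values are not taken from the commit); a local pad_to helper stands in for GGML_PAD under the assumption that n_pad is a power of two.

// Sketch of the SWA cache sizing: n_swa cells per sequence, plus headroom for
// one micro-batch of fresh tokens, rounded up to the padding granularity and
// capped at the base cache size.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t pad_to(uint32_t x, uint32_t n) {
    return (x + n - 1) & ~(n - 1); // round x up to a multiple of n (power of two)
}

int main() {
    const uint32_t n_swa     = 1024; // sliding-window size (hparams.n_swa)
    const uint32_t n_seq_max = 4;    // maximum number of parallel sequences
    const uint32_t n_ubatch  = 512;  // micro-batch size
    const uint32_t n_pad     = 256;  // cache padding granularity
    const uint32_t size_base = 8192; // base (non-SWA) cache size, i.e. kv_size

    const uint32_t size_swa = std::min(size_base, pad_to(n_swa*n_seq_max + n_ubatch, n_pad));

    printf("size_swa = %u cells\n", size_swa); // 1024*4 + 512 = 4608, already a multiple of 256
    return 0;
}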

src/llama-kv-cache.h

Lines changed: 1 addition & 1 deletion

@@ -251,7 +251,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
             bool     swa_full,
             uint32_t kv_size,
             uint32_t n_seq_max,
-            uint32_t n_batch,
+            uint32_t n_ubatch,
             uint32_t n_pad);
 
     ~llama_kv_cache_unified_iswa() = default;

src/llama-model.cpp

Lines changed: 5 additions & 1 deletion

@@ -13234,7 +13234,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         params.swa_full,
                         cparams.n_ctx,
                         cparams.n_seq_max,
-                        cparams.n_batch,
+                        cparams.n_ubatch,
                         padding);
             } else {
                 GGML_ASSERT(!hparams.is_swa_any());

@@ -13597,6 +13597,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
     return model->hparams.n_head_kv();
 }
 
+int32_t llama_model_n_swa(const llama_model * model) {
+    return model->hparams.n_swa;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
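A rough illustration of why create_memory now passes cparams.n_ubatch: decoding fills the cache one micro-batch at a time, so the SWA cache only needs headroom for a single micro-batch of in-flight tokens rather than the full logical batch. With illustrative default-like values (not taken from the commit), the saving looks like this:

// Old vs. new SWA cache size for one sequence (values are illustrative).
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t pad_to(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); }

int main() {
    const uint32_t n_swa = 1024, n_seq_max = 1, n_pad = 256, size_base = 8192;

    const uint32_t old_size = std::min(size_base, pad_to(n_swa*n_seq_max + 2048, n_pad)); // previously: + n_batch
    const uint32_t new_size = std::min(size_base, pad_to(n_swa*n_seq_max +  512, n_pad)); // now:        + n_ubatch

    printf("old = %u cells, new = %u cells\n", old_size, new_size); // 3072 vs 1536
    return 0;
}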

tools/server/server.cpp

Lines changed: 3 additions & 7 deletions

@@ -2018,11 +2018,6 @@ struct server_context {
                 params_base.n_cache_reuse = 0;
                 SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
             }
-
-            if (!params_base.speculative.model.path.empty()) {
-                SRV_ERR("%s\n", "err: speculative decode is not supported by this context");
-                return false;
-            }
         }
 
         return true;

@@ -3216,9 +3211,10 @@ struct server_context {
                 }
 
                 if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
+                    const auto n_swa = llama_model_n_swa(model);
                     const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id);
-                    if (pos_min > 0) {
-                        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
+                    if (pos_min == -1 || pos_min > slot.n_past - n_swa) {
+                        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
                         SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                 "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                         slot.n_past = 0;
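For clarity, the updated re-processing condition can be read as the standalone predicate below; this is a sketch, the helper name is made up, and the real check lives inline in server.cpp as shown above.

// Hypothetical predicate mirroring the check above: returns true when the KV
// cache can no longer serve the positions needed to continue from n_past, so
// the prompt must be re-processed from position 0.
#include <cstdint>

static bool must_reprocess_prompt(int32_t pos_min, int32_t n_past, int32_t n_swa) {
    // pos_min == -1: the sequence has no cells left in the KV cache at all
    if (pos_min == -1) {
        return true;
    }

    // with sliding-window attention, continuing at position n_past needs the
    // previous n_swa positions to still be cached; if the oldest cached
    // position is newer than n_past - n_swa, that context was already evicted
    return pos_min > n_past - n_swa;
}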
