
Commit 7c1bdd0

llama : apply K-cache roping for Falcon and Baichuan
1 parent 0cbf3bf

File tree

1 file changed: +49 −0 lines changed


llama.cpp

Lines changed: 49 additions & 0 deletions
@@ -2746,6 +2746,7 @@ static struct ggml_cgraph * llm_build_llama(
             ggml_set_name(cur, "attention_norm_0");
         }
 
+        // shift the entire K-cache if needed
         if (do_rope_shift) {
             ggml_build_forward_expand(gf,
                     ggml_rope_custom_inplace(ctx0,
@@ -2987,6 +2988,8 @@ static struct ggml_cgraph * llm_build_baichaun(
     const int32_t n_tokens = batch.n_tokens;
     const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 
+    const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure(lctx.alloc);
+
     auto & buf_compute = lctx.buf_compute;
 
     struct ggml_init_params params = {
@@ -3090,6 +3093,16 @@ static struct ggml_cgraph * llm_build_baichaun(
         }
     }
 
+    // K_shift
+    struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
+    ggml_allocr_alloc(lctx.alloc, K_shift);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        int * data = (int *) K_shift->data;
+        for (int i = 0; i < n_ctx; ++i) {
+            data[i] = kv_self.cells[i].delta;
+        }
+    }
+
     for (int il = 0; il < n_layer; ++il) {
         ggml_format_name(inpL, "layer_inp_%d", il);
 
@@ -3115,6 +3128,18 @@ static struct ggml_cgraph * llm_build_baichaun(
             ggml_set_name(cur, "attention_norm_0");
         }
 
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            ggml_build_forward_expand(gf,
+                    ggml_rope_custom_inplace(ctx0,
+                        ggml_view_3d(ctx0, kv_self.k,
+                            n_embd_head, n_head_kv, n_ctx,
+                            ggml_element_size(kv_self.k)*n_embd_head,
+                            ggml_element_size(kv_self.k)*n_embd_gqa,
+                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
+                        K_shift, n_embd_head, 0, 0, freq_base, freq_scale));
+        }
+
         // self-attention
         {
             // compute Q and K and RoPE them
@@ -3362,6 +3387,8 @@ static struct ggml_cgraph * llm_build_falcon(
     const int32_t n_tokens = batch.n_tokens;
     const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 
+    const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure(lctx.alloc);
+
     auto & buf_compute = lctx.buf_compute;
 
     struct ggml_init_params params = {
@@ -3465,6 +3492,16 @@ static struct ggml_cgraph * llm_build_falcon(
         }
     }
 
+    // K_shift
+    struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
+    ggml_allocr_alloc(lctx.alloc, K_shift);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        int * data = (int *) K_shift->data;
+        for (int i = 0; i < n_ctx; ++i) {
+            data[i] = kv_self.cells[i].delta;
+        }
+    }
+
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * attn_norm;
 
@@ -3476,6 +3513,18 @@ static struct ggml_cgraph * llm_build_falcon(
         }
 #endif // GGML_USE_CUBLAS
 
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            ggml_build_forward_expand(gf,
+                    ggml_rope_custom_inplace(ctx0,
+                        ggml_view_3d(ctx0, kv_self.k,
+                            n_embd_head, n_head_kv, n_ctx,
+                            ggml_element_size(kv_self.k)*n_embd_head,
+                            ggml_element_size(kv_self.k)*n_embd_gqa,
+                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
+                        K_shift, n_embd_head, 2, 0, freq_base, freq_scale));
+        }
+
         // self-attention
         // TODO: refactor into common function (shared with LLaMA)
         {
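Note: the added blocks re-rotate the cached K vectors in place by each cell's position delta (K_shift) instead of recomputing them, using ggml_rope_custom_inplace with mode 0 for Baichuan and mode 2 (NeoX-style pairing) for Falcon. Below is a minimal standalone sketch of the identity this relies on, not using ggml; the helper rope_pair, the toy sizes, and the adjacent-pair layout (mode 0 style) are illustrative assumptions, not code from llama.cpp.

// Why re-roping the K-cache by a per-cell delta works: RoPE rotations compose
// additively, so rotating a key that was roped at position p by angle(delta)
// gives the same vector as roping it directly at position p + delta.
#include <cmath>
#include <cstdio>
#include <vector>

// Rotate one (x0, x1) pair by the RoPE angle for dimension pair `d` at position `pos`.
static void rope_pair(float & x0, float & x1, int pos, int d, int n_dims, float freq_base = 10000.0f) {
    const float theta = pos * std::pow(freq_base, -2.0f * d / n_dims);
    const float c = std::cos(theta), s = std::sin(theta);
    const float r0 = x0 * c - x1 * s;
    const float r1 = x0 * s + x1 * c;
    x0 = r0; x1 = r1;
}

int main() {
    const int n_dims = 4;                // toy head size
    std::vector<float> k = {1, 2, 3, 4}; // un-roped key vector

    // Key as stored in the cache: RoPE applied at its original position.
    const int old_pos = 7, delta = -3;   // cell moved back by 3 (e.g. after a context shift)
    std::vector<float> cached = k;
    for (int d = 0; d < n_dims/2; ++d) rope_pair(cached[2*d], cached[2*d+1], old_pos, d, n_dims);

    // What the added code does conceptually: rotate the cached key by `delta` only.
    std::vector<float> shifted = cached;
    for (int d = 0; d < n_dims/2; ++d) rope_pair(shifted[2*d], shifted[2*d+1], delta, d, n_dims);

    // Reference: RoPE applied directly at the new position.
    std::vector<float> expected = k;
    for (int d = 0; d < n_dims/2; ++d) rope_pair(expected[2*d], expected[2*d+1], old_pos + delta, d, n_dims);

    for (int i = 0; i < n_dims; ++i) {
        std::printf("dim %d: shifted=%.6f expected=%.6f\n", i, shifted[i], expected[i]);
    }
    return 0;
}

The same additivity argument holds regardless of how the dimensions are paired, which is why only the mode argument (0 vs 2) differs between the Baichuan and Falcon calls above.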

0 commit comments
