9
9
#include < cmath>
10
10
#include < cstring>
11
11
12
- static int32_t llama_relative_position_bucket (llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
13
- // TODO move to hparams if a T5 variant appears that uses a different value
14
- const int64_t max_distance = 128 ;
15
-
16
- if (bidirectional) {
17
- n_buckets >>= 1 ;
18
- }
19
-
20
- const int64_t max_exact = n_buckets >> 1 ;
21
-
22
- int32_t relative_position = x - y;
23
- int32_t relative_bucket = 0 ;
24
-
25
- if (bidirectional) {
26
- relative_bucket += (relative_position > 0 ) * n_buckets;
27
- relative_position = abs (relative_position);
28
- } else {
29
- relative_position = -std::min<int32_t >(relative_position, 0 );
30
- }
31
-
32
- int32_t relative_position_if_large = floorf (max_exact + logf (1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log (1.0 * max_distance / max_exact));
33
- relative_position_if_large = std::min<int32_t >(relative_position_if_large, n_buckets - 1 );
34
- relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
35
-
36
- return relative_bucket;
37
- }
38
-
39
12
void llm_graph_input_embd::set_input (const llama_ubatch * ubatch) {
40
13
if (ubatch->token ) {
41
14
const int64_t n_tokens = ubatch->n_tokens ;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
110
83
111
84
void llm_graph_input_pos_bucket_kv::set_input (const llama_ubatch * ubatch) {
112
85
if (pos_bucket) {
113
- const int64_t n_tokens = ubatch->n_tokens ;
114
-
115
- GGML_ASSERT (ggml_backend_buffer_is_host (pos_bucket->buffer ));
116
- GGML_ASSERT (!ubatch->equal_seqs ); // TODO: use ubatch->n_seqs instead of failing
117
-
118
- int32_t * data = (int32_t *) pos_bucket->data ;
119
-
120
- const int64_t n_kv = kv_self->n ;
121
-
122
- for (int h = 0 ; h < 1 ; ++h) {
123
- for (int j = 0 ; j < n_tokens; ++j) {
124
- for (int i = 0 ; i < n_kv; ++i) {
125
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket (kv_self->cells [i].pos , ubatch->pos [j], hparams.n_rel_attn_bkts , false );
126
- }
127
- }
128
- }
86
+ kv_self->set_input_pos_bucket (pos_bucket, ubatch);
129
87
}
130
88
}
131
89
@@ -403,99 +361,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
403
361
}
404
362
405
363
void llm_graph_input_attn_kv_unified::set_input (const llama_ubatch * ubatch) {
406
- if (self_kq_mask || self_kq_mask_swa) {
407
- const int64_t n_kv = kv_self->n ;
408
- const int64_t n_tokens = ubatch->n_tokens ;
409
- const int64_t n_seq_tokens = ubatch->n_seq_tokens ;
410
- const int64_t n_seqs = ubatch->n_seqs ;
411
-
412
- float * data = nullptr ;
413
- float * data_swa = nullptr ;
414
-
415
- if (self_kq_mask) {
416
- GGML_ASSERT (ggml_backend_buffer_is_host (self_kq_mask->buffer ));
417
- data = (float *) self_kq_mask->data ;
418
- }
419
-
420
- if (self_kq_mask_swa) {
421
- GGML_ASSERT (ggml_backend_buffer_is_host (self_kq_mask_swa->buffer ));
422
- data_swa = (float *) self_kq_mask_swa->data ;
423
- }
424
-
425
- // Use only the previous KV cells of the correct sequence for each token of the ubatch.
426
- // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
427
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
428
- // Causal mask:
429
- // xxx-------
430
- // xxxx------
431
- // xxxxx-----
432
- // Non-causal mask:
433
- // xxxxx-----
434
- // xxxxx-----
435
- // xxxxx-----
436
- // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
437
- for (int h = 0 ; h < 1 ; ++h) {
438
- for (int s = 0 ; s < n_seqs; ++s) {
439
- const llama_seq_id seq_id = ubatch->seq_id [s][0 ];
440
-
441
- for (int j = 0 ; j < n_seq_tokens; ++j) {
442
- const llama_pos pos = ubatch->pos [s*n_seq_tokens + j];
443
- for (int i = 0 ; i < n_kv; ++i) {
444
- float f;
445
- // mask the token if:
446
- if (!kv_self->cells [i].has_seq_id (seq_id) // not the correct sequence
447
- || (cparams.causal_attn && kv_self->cells [i].pos > pos) // for causal, mask future tokens
448
- ) {
449
- f = -INFINITY;
450
- } else {
451
- if (hparams.use_alibi ) {
452
- f = -std::abs (kv_self->cells [i].pos - pos);
453
- } else {
454
- f = 0 .0f ;
455
- }
456
- }
457
-
458
- if (data) {
459
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
460
- }
461
-
462
- // may need to cut off old tokens for sliding window
463
- // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
464
- if (data_swa) {
465
- if (hparams.n_attn_chunk ) {
466
- llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk ) * hparams.n_attn_chunk ;
467
- if (kv_self->cells [i].pos < pos_chunk_start || pos < pos_chunk_start) {
468
- f = -INFINITY;
469
- }
470
- } else {
471
- if (pos - kv_self->cells [i].pos >= (int32_t )hparams.n_swa ) {
472
- f = -INFINITY;
473
- }
474
- }
475
- data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
476
- }
477
- }
478
- }
479
- }
480
-
481
- // mask padded tokens
482
- if (data) {
483
- for (int i = n_tokens; i < GGML_PAD (n_tokens, GGML_KQ_MASK_PAD); ++i) {
484
- for (int j = 0 ; j < n_kv; ++j) {
485
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
486
- }
487
- }
488
- }
364
+ if (self_kq_mask) {
365
+ kv_self->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
366
+ }
489
367
490
- // mask padded tokens
491
- if (data_swa) {
492
- for (int i = n_tokens; i < GGML_PAD (n_tokens, GGML_KQ_MASK_PAD); ++i) {
493
- for (int j = 0 ; j < n_kv; ++j) {
494
- data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
495
- }
496
- }
497
- }
498
- }
368
+ if (self_kq_mask_swa) {
369
+ kv_self->set_input_kq_mask_swa (self_kq_mask_swa, ubatch, cparams.causal_attn );
499
370
}
500
371
}
501
372
@@ -1152,7 +1023,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1152
1023
1153
1024
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
1154
1025
1155
- const auto n_kv = kv_self->n ;
1026
+ const auto n_kv = kv_self->n_base () ;
1156
1027
1157
1028
auto & cur = inp->pos_bucket ;
1158
1029
@@ -1368,17 +1239,21 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
1368
1239
1369
1240
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
1370
1241
1371
- const auto n_kv = kv_self->n ;
1242
+ {
1243
+ const auto n_kv = kv_self->n_base ();
1372
1244
1373
- inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1374
- // cb(inp->self_kq_mask, "KQ_mask", -1);
1375
- ggml_set_input (inp->self_kq_mask );
1245
+ inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1246
+ // cb(inp->self_kq_mask, "KQ_mask", -1);
1247
+ ggml_set_input (inp->self_kq_mask );
1376
1248
1377
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
1249
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
1250
+ }
1378
1251
1379
1252
if (hparams.n_swa_pattern > 1 ) {
1380
1253
GGML_ASSERT (hparams.n_swa > 0 );
1381
1254
1255
+ const auto n_kv = kv_self->n_swa ();
1256
+
1382
1257
inp->self_kq_mask_swa = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1383
1258
// cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1384
1259
ggml_set_input (inp->self_kq_mask_swa );
@@ -1408,6 +1283,9 @@ ggml_tensor * llm_graph_context::build_attn(
1408
1283
ggml_build_forward_expand (gf, v_cur);
1409
1284
1410
1285
const llama_kv_cache_unified * kv_self = static_cast <const llama_kv_cache_unified *>(memory);
1286
+
1287
+ const auto & kv_layer = kv_self->get_layer (il);
1288
+
1411
1289
const auto & n_ctx = cparams.n_ctx ;
1412
1290
1413
1291
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa (il);
@@ -1419,11 +1297,11 @@ ggml_tensor * llm_graph_context::build_attn(
1419
1297
1420
1298
// store to KV cache
1421
1299
{
1422
- const auto kv_head = kv_self ->head ;
1300
+ const auto kv_head = kv_layer. cells ->head ;
1423
1301
1424
- GGML_ASSERT (kv_self ->size == n_ctx);
1302
+ GGML_ASSERT (kv_layer. cells ->size == n_ctx);
1425
1303
1426
- ggml_tensor * k_cache_view = ggml_view_1d (ctx0, kv_self-> k_l [il] , n_tokens*n_embd_k_gqa, ggml_row_size (kv_self-> k_l [il] ->type , n_embd_k_gqa)*kv_head);
1304
+ ggml_tensor * k_cache_view = ggml_view_1d (ctx0, kv_layer. k , n_tokens*n_embd_k_gqa, ggml_row_size (kv_layer. k ->type , n_embd_k_gqa)*kv_head);
1427
1305
// cb(k_cache_view, "k_cache_view", il);
1428
1306
1429
1307
// note: storing RoPE-ed version of K in the KV cache
@@ -1434,12 +1312,12 @@ ggml_tensor * llm_graph_context::build_attn(
1434
1312
ggml_tensor * v_cache_view = nullptr ;
1435
1313
1436
1314
if (!v_trans) {
1437
- v_cache_view = ggml_view_1d (ctx0, kv_self-> v_l [il] , n_tokens*n_embd_v_gqa, ggml_row_size (kv_self-> v_l [il] ->type , n_embd_v_gqa)*kv_head);
1315
+ v_cache_view = ggml_view_1d (ctx0, kv_layer. v , n_tokens*n_embd_v_gqa, ggml_row_size (kv_layer. v ->type , n_embd_v_gqa)*kv_head);
1438
1316
} else {
1439
1317
// note: the V cache is transposed when not using flash attention
1440
- v_cache_view = ggml_view_2d (ctx0, kv_self-> v_l [il] , n_tokens, n_embd_v_gqa,
1441
- ( n_ctx)*ggml_element_size (kv_self-> v_l [il] ),
1442
- (kv_head)*ggml_element_size (kv_self-> v_l [il] ));
1318
+ v_cache_view = ggml_view_2d (ctx0, kv_layer. v , n_tokens, n_embd_v_gqa,
1319
+ ( n_ctx)*ggml_element_size (kv_layer. v ),
1320
+ (kv_head)*ggml_element_size (kv_layer. v ));
1443
1321
1444
1322
v_cur = ggml_transpose (ctx0, v_cur);
1445
1323
}
@@ -1449,12 +1327,11 @@ ggml_tensor * llm_graph_context::build_attn(
1449
1327
}
1450
1328
1451
1329
const bool is_swa = hparams.is_swa (il);
1330
+ const int64_t n_head_kv = hparams.n_head_kv (il);
1452
1331
1453
1332
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa () : inp->get_kq_mask ();
1454
1333
1455
- const auto n_kv = kv_self->n ;
1456
-
1457
- const int64_t n_head_kv = hparams.n_head_kv (il);
1334
+ const auto n_kv = kv_layer.cells ->n ;
1458
1335
1459
1336
const auto & n_embd_head_k = hparams.n_embd_head_k ;
1460
1337
const auto & n_embd_head_v = hparams.n_embd_head_v ;
@@ -1463,23 +1340,23 @@ ggml_tensor * llm_graph_context::build_attn(
1463
1340
// cb(q, "q", il);
1464
1341
1465
1342
ggml_tensor * k =
1466
- ggml_view_3d (ctx0, kv_self-> k_l [il] ,
1343
+ ggml_view_3d (ctx0, kv_layer. k ,
1467
1344
n_embd_head_k, n_kv, n_head_kv,
1468
- ggml_row_size (kv_self-> k_l [il] ->type , n_embd_k_gqa),
1469
- ggml_row_size (kv_self-> k_l [il] ->type , n_embd_head_k),
1345
+ ggml_row_size (kv_layer. k ->type , n_embd_k_gqa),
1346
+ ggml_row_size (kv_layer. k ->type , n_embd_head_k),
1470
1347
0 );
1471
1348
// cb(k, "k", il);
1472
1349
1473
1350
ggml_tensor * v = !v_trans ?
1474
- ggml_view_3d (ctx0, kv_self-> v_l [il] ,
1351
+ ggml_view_3d (ctx0, kv_layer. v ,
1475
1352
n_embd_head_v, n_kv, n_head_kv,
1476
- ggml_row_size (kv_self-> v_l [il] ->type , n_embd_v_gqa),
1477
- ggml_row_size (kv_self-> v_l [il] ->type , n_embd_head_v),
1353
+ ggml_row_size (kv_layer. v ->type , n_embd_v_gqa),
1354
+ ggml_row_size (kv_layer. v ->type , n_embd_head_v),
1478
1355
0 ) :
1479
- ggml_view_3d (ctx0, kv_self-> v_l [il] ,
1356
+ ggml_view_3d (ctx0, kv_layer. v ,
1480
1357
n_kv, n_embd_head_v, n_head_kv,
1481
- ggml_element_size (kv_self-> v_l [il] )*n_ctx,
1482
- ggml_element_size (kv_self-> v_l [il] )*n_ctx*n_embd_head_v,
1358
+ ggml_element_size (kv_layer. v )*n_ctx,
1359
+ ggml_element_size (kv_layer. v )*n_ctx*n_embd_head_v,
1483
1360
0 );
1484
1361
1485
1362
ggml_tensor * cur = build_attn_mha (gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
@@ -1711,3 +1588,30 @@ void llm_graph_context::build_pooling(
1711
1588
1712
1589
ggml_build_forward_expand (gf, cur);
1713
1590
}
1591
+
1592
+ int32_t llama_relative_position_bucket (llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
1593
+ // TODO move to hparams if a T5 variant appears that uses a different value
1594
+ const int64_t max_distance = 128 ;
1595
+
1596
+ if (bidirectional) {
1597
+ n_buckets >>= 1 ;
1598
+ }
1599
+
1600
+ const int64_t max_exact = n_buckets >> 1 ;
1601
+
1602
+ int32_t relative_position = x - y;
1603
+ int32_t relative_bucket = 0 ;
1604
+
1605
+ if (bidirectional) {
1606
+ relative_bucket += (relative_position > 0 ) * n_buckets;
1607
+ relative_position = abs (relative_position);
1608
+ } else {
1609
+ relative_position = -std::min<int32_t >(relative_position, 0 );
1610
+ }
1611
+
1612
+ int32_t relative_position_if_large = floorf (max_exact + logf (1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log (1.0 * max_distance / max_exact));
1613
+ relative_position_if_large = std::min<int32_t >(relative_position_if_large, n_buckets - 1 );
1614
+ relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
1615
+
1616
+ return relative_bucket;
1617
+ }
0 commit comments