@@ -2817,6 +2817,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV",
+    "RWKV_TOKEN_SHIFT",

     "UNARY",

@@ -2835,7 +2836,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2905,7 +2906,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
-    "rwkv_wkv(x, k, v, r, tf, td, s)",
+    "rwkv_wkv(k, v, r, tf, td, s, sq)",
+    "rwkv_token_shift(xc, xn, sq)",

     "unary(x)",

@@ -2924,7 +2926,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");


 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
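
These four hunks keep the parallel GGML_OP_NAME and GGML_OP_SYMBOL string tables in sync with the new op; both static_asserts move from 75 to 76 so a missed table entry fails at compile time. The GGML_OP_RWKV_TOKEN_SHIFT enumerator itself is presumably added to enum ggml_op in ggml.h (not shown in this diff), which is what raises GGML_OP_COUNT.
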
@@ -7499,35 +7501,39 @@ struct ggml_tensor * ggml_rwkv_wkv(
         struct ggml_tensor * r,
         struct ggml_tensor * tf,
         struct ggml_tensor * td,
-        struct ggml_tensor * state) {
+        struct ggml_tensor * state,
+        struct ggml_tensor * state_seq) {
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
     GGML_ASSERT(ggml_is_contiguous(r));
     GGML_ASSERT(ggml_is_contiguous(tf));
     GGML_ASSERT(ggml_is_contiguous(td));
     GGML_ASSERT(ggml_is_contiguous(state));
+    GGML_ASSERT(ggml_is_contiguous(state_seq));
+    GGML_ASSERT(state_seq->type == GGML_TYPE_I32);

     const int64_t S = k->ne[0];
     const int64_t H = k->ne[2];
     const int64_t n_tokens = k->ne[3];
+    const int64_t n_kv = state_seq->ne[0];
     {
         GGML_ASSERT(k->ne[1] == 1);
         GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
         GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
         // TODO: RWKV v4 and v5
         GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
-        GGML_ASSERT(ggml_nelements(state) == S * S * H);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_kv);
     }

     bool is_node = false;

-    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad || state_seq->grad) {
         GGML_ABORT("fatal error"); // TODO: implement backward
         is_node = true;
     }

     // concat output and new_state
-    const int64_t ne[4] = { S * H, n_tokens + S, 1, 1 };
+    const int64_t ne[4] = { S * H, n_tokens + S * n_kv, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

     result->op = GGML_OP_RWKV_WKV;
@@ -7538,6 +7544,48 @@ struct ggml_tensor * ggml_rwkv_wkv(
     result->src[3] = tf;
     result->src[4] = td;
     result->src[5] = state;
+    result->src[6] = state_seq;
+
+    return result;
+}
+
+// ggml_rwkv_token_shift
+
+struct ggml_tensor * ggml_rwkv_token_shift(
+        struct ggml_context * ctx,
+        struct ggml_tensor * x_carry,
+        struct ggml_tensor * x_norm,
+        struct ggml_tensor * state_seq) {
+    GGML_ASSERT(ggml_is_contiguous(x_carry));
+    GGML_ASSERT(ggml_is_contiguous(x_norm));
+    GGML_ASSERT(ggml_is_contiguous(state_seq));
+    GGML_ASSERT(state_seq->type == GGML_TYPE_I32);
+
+    const int64_t n_embd = x_norm->ne[0];
+    const int64_t n_kv = state_seq->ne[0];
+    const int64_t n_tokens = state_seq->ne[1];
+    {
+        GGML_ASSERT(x_norm->ne[0] == n_embd);
+        GGML_ASSERT(x_norm->ne[1] == n_tokens);
+        GGML_ASSERT(ggml_nelements(x_carry) == n_embd * n_kv);
+    }
+
+    bool is_node = false;
+
+    if (x_carry->grad || x_norm->grad || state_seq->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { n_embd, n_tokens + n_kv, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op = GGML_OP_RWKV_TOKEN_SHIFT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = x_carry;
+    result->src[1] = x_norm;
+    result->src[2] = state_seq;

     return result;
 }
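
Both builders return a single concatenated F32 tensor: the per-token output comes first, followed by the updated recurrent state, so the caller is expected to slice the result apart. A minimal sketch of that split using the standard ggml view API; the helper name and call shape are illustrative assumptions, not part of this commit:

#include "ggml.h"

// Hypothetical helper (not from this commit): split the concatenated result
// of ggml_rwkv_wkv into the per-token output (S*H x n_tokens) and the
// updated state (S*S*H floats per KV cell, n_kv cells).
static void rwkv_split_wkv(struct ggml_context * ctx, struct ggml_tensor * wkv,
                           int64_t S, int64_t H, int64_t n_tokens, int64_t n_kv,
                           struct ggml_tensor ** out, struct ggml_tensor ** state) {
    // first S * H * n_tokens floats: the wkv output for each token
    *out = ggml_view_1d(ctx, wkv, S * H * n_tokens, 0);
    // remaining S * S * H * n_kv floats: one state block per KV cell
    *state = ggml_view_1d(ctx, wkv, S * S * H * n_kv,
                          S * H * n_tokens * ggml_element_size(wkv));
}

The result of ggml_rwkv_token_shift splits the same way: n_embd * n_tokens floats of shifted activations, then n_embd * n_kv floats of new carry.
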
@@ -16463,6 +16511,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     const size_t T = dst->src[1]->ne[3];
     const size_t C = dst->ne[0];
     const size_t H = dst->src[1]->ne[2];
+    const size_t n_kv = dst->src[6]->ne[0];

     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;
@@ -16478,7 +16527,8 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * r = (float *) dst->src[2]->data;
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
-    memcpy(state, dst->src[5]->data, (C / H) * C * sizeof(float));
+    int32_t * seq_data = (int32_t *) dst->src[6]->data;
+    memcpy(state, dst->src[5]->data, (C / H) * C * n_kv * sizeof(float));


     size_t t_stride = H * (C / H);
@@ -16491,6 +16541,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     // recursive through each token
     for (size_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
+        float * state_cur = state + (C / H) * C * seq_data[t * n_kv];

         for (size_t h = 0; h < H; h++) {
             size_t h_offset = h * h_stride;
@@ -16514,14 +16565,23 @@ static void ggml_compute_forward_rwkv_wkv_f32(

                     float v_val = v[t_h_j_offset];
                     float kv_val = v_val * k_val;
-                    float prev_state_val = state[h_2d_i_j_offset];
+                    float prev_state_val = state_cur[h_2d_i_j_offset];
                     float temp_val = kv_val * time_faaaa_val + prev_state_val;
                     dst_data[t_h_j_offset] += temp_val * r_val;
-                    state[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
                 }
             }
         }
     }
+
+    for (size_t t = 0; t < T; t++) {
+        for (size_t kv = 1; kv < n_kv; kv++) {
+            int64_t seq = seq_data[t * n_kv + kv];
+            if (seq >= 0 && seq_data[(t + 1) * n_kv + kv] != seq) {
+                memcpy(state + (C / H) * C * seq, state + (C / H) * C * seq_data[t * n_kv], (C / H) * C * sizeof(float));
+            }
+        }
+    }
 }

 static void ggml_compute_forward_rwkv_wkv(
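
The state buffer now holds one (C/H) x C block per KV cell instead of a single block, and seq_data (n_kv entries per token) lists the sequences each token belongs to: entry 0 selects the block updated in place, and the trailing loop copies that block into a secondary sequence's slot once that sequence's entries end. Note that the loop peeks at row t + 1 of seq_data, which suggests the seq tensor is expected to extend one row past the last token. A small model of the offset arithmetic, inferred from the indexing above rather than stated in the commit:

// Offset sketch: with head size S = C / H, each sequence's wkv state block
// holds (C / H) * C = S * S * H floats, matching `state_cur` above.
// E.g. C = 2048, H = 32 gives S = 64 and 64 * 2048 = 131072 floats
// (512 KiB of F32) per KV cell.
static inline float * wkv_state_for_seq(float * state, size_t C, size_t H, int32_t seq) {
    return state + (C / H) * C * (size_t) seq;
}
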
@@ -16542,6 +16602,77 @@ static void ggml_compute_forward_rwkv_wkv(
         }
     }
 }

+static void ggml_compute_forward_rwkv_token_shift_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t n_embd = dst->ne[0];
+    const int64_t n_kv = dst->src[2]->ne[0];
+    const int64_t n_tokens = dst->src[1]->ne[1];
+    float * dst_data = (float *) dst->data;
+    float * x_carry = (float *) dst->src[0]->data;
+    float * x_norm = (float *) dst->src[1]->data;
+    int32_t * sq_data = (int32_t *) dst->src[2]->data;
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    int32_t seq_start = 0;
+    int32_t seq_length = 0;
+
+    for (int i1 = 0; i1 < n_kv; ++i1) {
+        seq_start = -1;
+        // assume that the tokens for each sequence are contiguous
+        for (int i2 = 0; i2 < n_tokens; ++i2) {
+            int32_t seq = sq_data[i2*n_kv];
+            if (seq == i1 && seq_start < 0) {
+                seq_start = i2;
+            }
+
+            if ((seq_start >= 0 && seq != i1) || i2 == n_tokens - 1) {
+                seq_length = i2 - seq_start + (i2 == n_tokens - 1);
+                break;
+            }
+        }
+
+        if (seq_start >= 0) {
+            int32_t seq = sq_data[seq_start*n_kv];
+            memcpy(dst_data + seq_start*n_embd, x_carry + seq*n_embd, n_embd*sizeof(float));
+            memcpy(dst_data + (seq_start+1)*n_embd, x_norm + seq_start*n_embd, (seq_length-1)*n_embd*sizeof(float));
+        }
+    }
+
+    for (int i3 = 0; i3 < n_kv; ++i3) {
+        int32_t last_token_pos = 0;
+        for (int i4 = 0; i4 < n_tokens; ++i4) {
+            for (int i5 = 0; i5 < n_kv; ++i5) {
+                if (sq_data[i4*n_kv + i5] == i3) {
+                    last_token_pos = i4;
+                }
+            }
+        }
+        memcpy(dst_data + (n_tokens + i3)*n_embd, x_norm + last_token_pos*n_embd, n_embd*sizeof(float));
+    }
+}
+
+static void ggml_compute_forward_rwkv_token_shift(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_token_shift_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary

 static void ggml_compute_forward_map_unary_f32(
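
The kernel above implements RWKV's token shift per contiguous run of each sequence: the run's first output row is that sequence's carry from x_carry, the following rows are x_norm shifted forward by one position, and each sequence's last x_norm row is appended after the n_tokens output rows as the new carry. For a single sequence of four tokens x0..x3 with carry c, the output rows are [c, x0, x1, x2] and the new carry row is x3. A tiny standalone model of that per-sequence copy pattern (plain C, illustrative only, not part of the commit):

#include <string.h>

// Standalone model of token shift for one sequence of n_tokens >= 1 rows:
// dst[0] = carry, dst[i] = x[i-1] for 0 < i < n_tokens, new carry = x[n_tokens-1].
static void token_shift_one_seq(float * dst, float * new_carry,
                                const float * x, const float * carry,
                                size_t n_tokens, size_t n_embd) {
    memcpy(dst, carry, n_embd * sizeof(float));                       // row 0: carry in
    memcpy(dst + n_embd, x, (n_tokens - 1) * n_embd * sizeof(float)); // rows 1..n-1: x shifted
    memcpy(new_carry, x + (n_tokens - 1) * n_embd, n_embd * sizeof(float)); // carry out
}
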
@@ -17192,6 +17323,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv(params, tensor);
             } break;
+        case GGML_OP_RWKV_TOKEN_SHIFT:
+            {
+                ggml_compute_forward_rwkv_token_shift(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -18254,6 +18389,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
         case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_TOKEN_SHIFT:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
@@ -18824,6 +18960,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_TOKEN_SHIFT:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
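
In ggml_compute_backward and ggml_get_n_tasks the new op simply joins the existing RWKV_WKV bucket: no backward pass is defined (the builders abort if any input carries a gradient), and the forward pass is scheduled as a single task, consistent with ggml_compute_forward_rwkv_token_shift_f32 returning early when params->ith != 0. The surrounding lines are not shown in these hunks, but the case labels presumably resolve to n_tasks = 1.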