
Commit f0fa165

Fix parallel inferencing for RWKV

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

1 parent bb18d82 commit f0fa165

File tree

3 files changed

+232 -43 lines changed

ggml/include/ggml.h

Lines changed: 9 additions & 1 deletion
@@ -507,6 +507,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV,
+        GGML_OP_RWKV_TOKEN_SHIFT,
 
         GGML_OP_UNARY,
 
@@ -1857,7 +1858,14 @@ extern "C" {
             struct ggml_tensor * r,
             struct ggml_tensor * tf,
             struct ggml_tensor * td,
-            struct ggml_tensor * state);
+            struct ggml_tensor * state,
+            struct ggml_tensor * state_seq);
+
+    GGML_API struct ggml_tensor * ggml_rwkv_token_shift(
+            struct ggml_context * ctx,
+            struct ggml_tensor * x_carry,
+            struct ggml_tensor * x_norm,
+            struct ggml_tensor * state_seq);
 
     // custom operators
 
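For orientation, here is a minimal caller-side sketch of how the updated API might be wired up when building a graph. It is not part of this commit; the names (x_carry, x_norm, state_seq, n_embd, n_tokens, n_kv) are assumptions mirroring the shape asserts added in ggml/src/ggml.c below. Both ops return their per-token output concatenated with the new per-sequence state, so a caller would split the result with views:

// Hypothetical usage sketch, not code from this commit.
// state_seq is a GGML_TYPE_I32 tensor of shape [n_kv, n_tokens] mapping each
// token to the state slot(s) of the sequence(s) it belongs to.
struct ggml_tensor * shifted = ggml_rwkv_token_shift(ctx, x_carry, x_norm, state_seq);

// rows [0, n_tokens): the token-shifted activations
struct ggml_tensor * x_prev = ggml_view_1d(ctx, shifted, n_embd * n_tokens, 0);

// rows [n_tokens, n_tokens + n_kv): the new carry for each sequence slot
struct ggml_tensor * carry_out = ggml_view_1d(ctx, shifted, n_embd * n_kv,
        n_embd * n_tokens * ggml_element_size(shifted));
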
ggml/src/ggml.c

Lines changed: 147 additions & 10 deletions
@@ -2817,6 +2817,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
         "GET_REL_POS",
         "ADD_REL_POS",
         "RWKV_WKV",
+        "RWKV_TOKEN_SHIFT",
 
         "UNARY",
 
@@ -2835,7 +2836,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
         "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
         "none",
@@ -2905,7 +2906,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
         "win_unpart(x)",
         "get_rel_pos(x)",
         "add_rel_pos(x)",
-        "rwkv_wkv(x, k, v, r, tf, td, s)",
+        "rwkv_wkv(k, v, r, tf, td, s, sq)",
+        "rwkv_token_shift(xc, xn, sq)",
 
         "unary(x)",
 
@@ -2924,7 +2926,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
         "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -7499,35 +7501,39 @@ struct ggml_tensor * ggml_rwkv_wkv(
         struct ggml_tensor * r,
         struct ggml_tensor * tf,
         struct ggml_tensor * td,
-        struct ggml_tensor * state) {
+        struct ggml_tensor * state,
+        struct ggml_tensor * state_seq) {
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
     GGML_ASSERT(ggml_is_contiguous(r));
     GGML_ASSERT(ggml_is_contiguous(tf));
     GGML_ASSERT(ggml_is_contiguous(td));
     GGML_ASSERT(ggml_is_contiguous(state));
+    GGML_ASSERT(ggml_is_contiguous(state_seq));
+    GGML_ASSERT(state_seq->type == GGML_TYPE_I32);
 
     const int64_t S = k->ne[0];
     const int64_t H = k->ne[2];
     const int64_t n_tokens = k->ne[3];
+    const int64_t n_kv = state_seq->ne[0];
     {
         GGML_ASSERT(k->ne[1] == 1);
         GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
         GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
         // TODO: RWKV v4 and v5
         GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
-        GGML_ASSERT(ggml_nelements(state) == S * S * H);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_kv);
     }
 
     bool is_node = false;
 
-    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad || state_seq->grad) {
         GGML_ABORT("fatal error"); // TODO: implement backward
         is_node = true;
     }
 
     // concat output and new_state
-    const int64_t ne[4] = { S * H, n_tokens + S, 1, 1 };
+    const int64_t ne[4] = { S * H, n_tokens + S * n_kv, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     result->op = GGML_OP_RWKV_WKV;
@@ -7538,6 +7544,48 @@ struct ggml_tensor * ggml_rwkv_wkv(
     result->src[3] = tf;
     result->src[4] = td;
     result->src[5] = state;
+    result->src[6] = state_seq;
+
+    return result;
+}
+
+// ggml_rwkv_token_shift
+
+struct ggml_tensor * ggml_rwkv_token_shift(
+        struct ggml_context * ctx,
+        struct ggml_tensor * x_carry,
+        struct ggml_tensor * x_norm,
+        struct ggml_tensor * state_seq) {
+    GGML_ASSERT(ggml_is_contiguous(x_carry));
+    GGML_ASSERT(ggml_is_contiguous(x_norm));
+    GGML_ASSERT(ggml_is_contiguous(state_seq));
+    GGML_ASSERT(state_seq->type == GGML_TYPE_I32);
+
+    const int64_t n_embd = x_norm->ne[0];
+    const int64_t n_kv = state_seq->ne[0];
+    const int64_t n_tokens = state_seq->ne[1];
+    {
+        GGML_ASSERT(x_norm->ne[0] == n_embd);
+        GGML_ASSERT(x_norm->ne[1] == n_tokens);
+        GGML_ASSERT(ggml_nelements(x_carry) == n_embd * n_kv);
+    }
+
+    bool is_node = false;
+
+    if (x_carry->grad || x_norm->grad || state_seq->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { n_embd, n_tokens + n_kv, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op = GGML_OP_RWKV_TOKEN_SHIFT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = x_carry;
+    result->src[1] = x_norm;
+    result->src[2] = state_seq;
 
     return result;
 }
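
As with the token-shift builder, ggml_rwkv_wkv now returns the per-token output (S * H * n_tokens floats) concatenated with the updated head states (S * S * H * n_kv floats). A hedged sketch of how a caller might separate the two halves, using only the shapes computed above; the variable names are illustrative:

// Hypothetical split of the concatenated wkv result, not code from this
// commit; C = S * H is the hidden size, offsets follow the ne[] set above.
struct ggml_tensor * wkv = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state, state_seq);

struct ggml_tensor * wkv_out   = ggml_view_1d(ctx, wkv, S * H * n_tokens, 0);
struct ggml_tensor * state_out = ggml_view_1d(ctx, wkv, S * S * H * n_kv,
        S * H * n_tokens * ggml_element_size(wkv));
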
@@ -16463,6 +16511,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     const size_t T = dst->src[1]->ne[3];
     const size_t C = dst->ne[0];
     const size_t H = dst->src[1]->ne[2];
+    const size_t n_kv = dst->src[6]->ne[0];
 
     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;
@@ -16478,7 +16527,8 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * r = (float *) dst->src[2]->data;
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
-    memcpy(state, dst->src[5]->data, (C / H) * C * sizeof(float));
+    int32_t * seq_data = (int32_t *) dst->src[6]->data;
+    memcpy(state, dst->src[5]->data, (C / H) * C * n_kv * sizeof(float));
 
     size_t t_stride = H * (C / H);
 
@@ -16491,6 +16541,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     // recursive through each token
     for (size_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
+        float * state_cur = state + (C / H) * C * seq_data[t * n_kv];
 
         for (size_t h = 0; h < H; h++) {
             size_t h_offset = h * h_stride;
@@ -16514,14 +16565,23 @@ static void ggml_compute_forward_rwkv_wkv_f32(
                 float v_val = v[t_h_j_offset];
                 float kv_val = v_val * k_val;
-                float prev_state_val = state[h_2d_i_j_offset];
+                float prev_state_val = state_cur[h_2d_i_j_offset];
                 float temp_val = kv_val * time_faaaa_val + prev_state_val;
                 dst_data[t_h_j_offset] += temp_val * r_val;
-                state[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
             }
         }
     }
+
+    for (size_t t = 0; t < T; t++) {
+        for (size_t kv = 1; kv < n_kv; kv++) {
+            int64_t seq = seq_data[t * n_kv + kv];
+            if (seq >= 0 && seq_data[(t + 1) * n_kv + kv] != seq) {
+                memcpy(state + (C / H) * C * seq, state + (C / H) * C * seq_data[t * n_kv], (C / H) * C * sizeof(float));
+            }
+        }
+    }
 }
 
 static void ggml_compute_forward_rwkv_wkv(
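
For intuition, an illustration (inferred from the indexing above, not data from the commit) of the seq_data layout the kernel assumes: one n_kv-wide row per token, where entry 0 names the sequence whose state the token updates and later entries list additional sequences sharing that token, padded with -1:

// Illustrative seq_data for n_kv = 2 and T = 3 tokens, where tokens 1-2 are
// shared by sequences 0 and 1; the values are assumptions for illustration.
int32_t seq_data[] = {
    0, -1,   // token 0: updates the state of sequence 0 only
    0,  1,   // token 1: updates sequence 0; sequence 1 shares this token
    0,  1,   // token 2: same sharing
};
// The token loop above updates only the primary slot, seq_data[t * n_kv];
// the trailing loop then copies that state into each secondary sequence's
// slot at the point where the sharing ends, so every sequence leaves the
// batch with a consistent state.
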
@@ -16542,6 +16602,77 @@ static void ggml_compute_forward_rwkv_wkv(
     }
 }
 
+static void ggml_compute_forward_rwkv_token_shift_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t n_embd = dst->ne[0];
+    const int64_t n_kv = dst->src[2]->ne[0];
+    const int64_t n_tokens = dst->src[1]->ne[1];
+    float * dst_data = (float *) dst->data;
+    float * x_carry = (float *) dst->src[0]->data;
+    float * x_norm = (float *) dst->src[1]->data;
+    int32_t * sq_data = (int32_t *) dst->src[2]->data;
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    int32_t seq_start = 0;
+    int32_t seq_length = 0;
+
+    for (int i1 = 0; i1 < n_kv; ++i1) {
+        seq_start = -1;
+        // assume that the tokens for each sequence are contiguous
+        for (int i2 = 0; i2 < n_tokens; ++i2) {
+            int32_t seq = sq_data[i2*n_kv];
+            if (seq == i1 && seq_start < 0) {
+                seq_start = i2;
+            }
+
+            if ((seq_start >= 0 && seq != i1) || i2 == n_tokens - 1) {
+                seq_length = i2 - seq_start + (i2 == n_tokens - 1);
+                break;
+            }
+        }
+
+        if (seq_start >= 0) {
+            int32_t seq = sq_data[seq_start*n_kv];
+            memcpy(dst_data + seq_start*n_embd, x_carry + seq*n_embd, n_embd*sizeof(float));
+            memcpy(dst_data + (seq_start+1)*n_embd, x_norm + seq_start*n_embd, (seq_length-1)*n_embd*sizeof(float));
+        }
+    }
+
+    for (int i3 = 0; i3 < n_kv; ++i3) {
+        int32_t last_token_pos = 0;
+        for (int i4 = 0; i4 < n_tokens; ++i4) {
+            for (int i5 = 0; i5 < n_kv; ++i5) {
+                if (sq_data[i4*n_kv + i5] == i3) {
+                    last_token_pos = i4;
+                }
+            }
+        }
+        memcpy(dst_data + (n_tokens + i3)*n_embd, x_norm + last_token_pos*n_embd, n_embd*sizeof(float));
+    }
+}
+
+static void ggml_compute_forward_rwkv_token_shift(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_token_shift_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
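
A small worked example of what this kernel produces, under the code's own assumption that each sequence's tokens are contiguous in the batch; the concrete values are illustrative, not from the commit:

// Worked example: n_embd = 1, n_tokens = 4, n_kv = 2, tokens 0-1 in
// sequence 0 and tokens 2-3 in sequence 1 (sq_data rows are n_kv wide,
// -1 marks an unused slot). Values are arbitrary.
float   x_norm [4] = { /* a */ 1.f, /* b */ 2.f, /* c */ 3.f, /* d */ 4.f };
float   x_carry[2] = { /* c0 */ 9.f, /* c1 */ 8.f };
int32_t sq_data[8] = { 0, -1,  0, -1,  1, -1,  1, -1 };
// Expected dst (n_tokens + n_kv rows of n_embd): { c0, a, c1, c, b, d }
//  - the first row of each sequence comes from x_carry (last batch's tail)
//  - every later row is the previous token's x_norm (the "token shift")
//  - the final n_kv rows are each sequence's last x_norm, i.e. the new carry
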
@@ -17192,6 +17323,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_rwkv_wkv(params, tensor);
             } break;
+        case GGML_OP_RWKV_TOKEN_SHIFT:
+            {
+                ggml_compute_forward_rwkv_token_shift(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -18254,6 +18389,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table) {
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
         case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_TOKEN_SHIFT:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
@@ -18824,6 +18960,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_TOKEN_SHIFT:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
