Commit 487fb6d

build_rwkv: Avoid using inplace operations
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
1 parent b1b6c7e commit 487fb6d
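
Note: the patch swaps ggml's *_inplace operator variants (ggml_add_inplace, ggml_mul_inplace, ggml_tanh_inplace, ggml_silu_inplace, ...) for their regular out-of-place counterparts in llm_build_time_mix and llm_build_channel_mix, and renames the working tensor from current/x to cur. An inplace variant returns a view of its first operand and writes the result into that operand's buffer, while the out-of-place form adds a fresh tensor node with its own storage. A minimal sketch of the pattern as it reads after the change (illustrative only, not code from the commit; the helper name is made up, the tensor names mirror the diff):

#include "ggml.h"

// Token-shift interpolation in the out-of-place style used after this patch:
// every intermediate is its own graph node, so cur and sx stay untouched for
// the remaining lerp branches.
static struct ggml_tensor * lerp_out_of_place(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,      // current token activations
        struct ggml_tensor  * x_prev,   // shifted (previous-token) activations
        struct ggml_tensor  * mix) {    // learned interpolation coefficients
    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
    // the pre-patch code used ggml_add_inplace() here, which would alias the
    // result with the ggml_mul() output instead of allocating a new node
    return ggml_add(ctx, ggml_mul(ctx, sx, mix), cur);
}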

File tree

1 file changed: +61 -85 lines changed

src/llama.cpp

Lines changed: 61 additions & 85 deletions
@@ -8563,36 +8563,29 @@ static struct ggml_tensor * llm_build_kv(
 static struct ggml_tensor * llm_build_time_mix(
         struct ggml_context * ctx,
         const struct llama_layer * layer,
-        struct ggml_tensor * current,
+        struct ggml_tensor * cur,
         struct ggml_tensor * x_prev,
         struct ggml_tensor ** wkv_state,
         struct ggml_tensor * state_seq) {
-    size_t n_embed = current->ne[0];
-    size_t n_tokens = current->ne[1];
+    size_t n_embed = cur->ne[0];
+    size_t n_tokens = cur->ne[1];
     size_t head_size = layer->time_mix_first->ne[0];
     size_t head_count = layer->time_mix_first->ne[1];
     size_t n_kv = state_seq->ne[0];
 
-    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
-    struct ggml_tensor * xxx = ggml_add_inplace(
-        ctx,
-        ggml_mul(ctx, sx, layer->time_mix_lerp_x),
-        current
-    );
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+    struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
 
     xxx = ggml_reshape_4d(
         ctx,
-        ggml_tanh_inplace(
+        ggml_tanh(
             ctx,
             ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
         ),
         layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
     );
 
-    xxx = ggml_cont(
-        ctx,
-        ggml_permute(ctx, xxx, 0, 1, 3, 2)
-    );
+    xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2));
 
     xxx = ggml_mul_mat(
         ctx,
@@ -8614,151 +8607,138 @@ static struct ggml_tensor * llm_build_time_mix(
     struct ggml_tensor *mk = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
     mk = ggml_reshape_2d(
         ctx,
-        ggml_set_1d_inplace(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * sizeof(float)), 0),
+        ggml_set_1d(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * sizeof(float)), 0),
         n_embed, n_tokens
     );
 
     struct ggml_tensor *mv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
     mv = ggml_reshape_2d(
         ctx,
-        ggml_set_1d_inplace(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 2 * sizeof(float)), 0),
+        ggml_set_1d(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 2 * sizeof(float)), 0),
         n_embed, n_tokens
     );
 
     struct ggml_tensor *mr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
     mr = ggml_reshape_2d(
         ctx,
-        ggml_set_1d_inplace(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 3 * sizeof(float)), 0),
+        ggml_set_1d(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 3 * sizeof(float)), 0),
         n_embed, n_tokens
     );
 
     struct ggml_tensor *mg = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
     mg = ggml_reshape_2d(
         ctx,
-        ggml_set_1d_inplace(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 4 * sizeof(float)), 0),
+        ggml_set_1d(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 4 * sizeof(float)), 0),
         n_embed, n_tokens
     );
 
-    struct ggml_tensor * xw = ggml_add_inplace(
+    struct ggml_tensor * xw = ggml_add(
         ctx,
-        ggml_mul_inplace(
+        ggml_mul(
             ctx,
             ggml_add(ctx, mw, layer->time_mix_lerp_w),
             sx
         ),
-        current
+        cur
     );
 
-    struct ggml_tensor * xk = ggml_add_inplace(
+    struct ggml_tensor * xk = ggml_add(
         ctx,
-        ggml_mul_inplace(
+        ggml_mul(
             ctx,
             ggml_add(ctx, mk, layer->time_mix_lerp_k),
             sx
         ),
-        current
+        cur
     );
 
-    struct ggml_tensor * xv = ggml_add_inplace(
+    struct ggml_tensor * xv = ggml_add(
         ctx,
-        ggml_mul_inplace(
+        ggml_mul(
             ctx,
             ggml_add(ctx, mv, layer->time_mix_lerp_v),
             sx
         ),
-        current
+        cur
     );
 
-    struct ggml_tensor * xr = ggml_add_inplace(
+    struct ggml_tensor * xr = ggml_add(
         ctx,
-        ggml_mul_inplace(
+        ggml_mul(
             ctx,
             ggml_add(ctx, mr, layer->time_mix_lerp_r),
             sx
         ),
-        current
+        cur
    );
 
-    struct ggml_tensor * xg = ggml_add_inplace(
+    struct ggml_tensor * xg = ggml_add(
         ctx,
-        ggml_mul_inplace(
+        ggml_mul(
             ctx,
             ggml_add(ctx, mg, layer->time_mix_lerp_g),
             sx
         ),
-        current
+        cur
     );
 
     struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
     struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
     struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
-    struct ggml_tensor * g = ggml_silu_inplace(
+    struct ggml_tensor * g = ggml_silu(
         ctx,
         ggml_mul_mat(ctx, layer->time_mix_gate, xg)
     );
 
     struct ggml_tensor * w = ggml_mul_mat(
         ctx,
         layer->time_mix_decay_w2,
-        ggml_tanh_inplace(
+        ggml_tanh(
             ctx,
             ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
         )
     );
-    w = ggml_add_inplace(
-        ctx,
-        w,
-        ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed)
-    );
+    w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed));
     w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
     w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
 
     k = ggml_transpose(ctx, k);
     v = ggml_transpose(ctx, v);
     r = ggml_transpose(ctx, r);
     struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state, state_seq);
-    current = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
+    cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_kv, n_embed * n_tokens * sizeof(float));
 
     // ggml_group_norm considers groups in the third dimension.
-    current = ggml_reshape_4d(ctx, current, 1, 1, n_embed, n_tokens);
-    current = ggml_group_norm(ctx, current, head_count, 64e-5f);
+    cur = ggml_reshape_4d(ctx, cur, 1, 1, n_embed, n_tokens);
+    cur = ggml_group_norm(ctx, cur, head_count, 64e-5f);
     // Convert back to a regular vector.
-    current = ggml_reshape_2d(ctx, current, n_embed, n_tokens);
-    current = ggml_add_inplace(
-        ctx,
-        ggml_mul_inplace(
-            ctx,
-            current,
-            layer->time_mix_ln
-        ),
-        layer->time_mix_ln_b
-    );
+    cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
+    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
 
-    current = ggml_mul(ctx, current, g);
+    cur = ggml_mul(ctx, cur, g);
 
-    return ggml_mul_mat(ctx, layer->time_mix_output, current);
+    return ggml_mul_mat(ctx, layer->time_mix_output, cur);
 }
 
 static struct ggml_tensor * llm_build_channel_mix(
         struct ggml_context * ctx,
         const struct llama_layer * layer,
-        struct ggml_tensor * current,
+        struct ggml_tensor * cur,
         struct ggml_tensor * x_prev) {
-    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
-    struct ggml_tensor * xk = ggml_add_inplace(
-        ctx,
-        ggml_mul(ctx, sx, layer->channel_mix_lerp_k),
-        current
-    );
-    struct ggml_tensor * xr = ggml_add_inplace(
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+    struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
+    struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
+
+    struct ggml_tensor * r = ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr));
+    struct ggml_tensor * k = ggml_sqr(
         ctx,
-        ggml_mul(ctx, sx, layer->channel_mix_lerp_r),
-        current
+        ggml_relu(
+            ctx,
+            ggml_mul_mat(ctx, layer->channel_mix_key, xk)
+        )
     );
-    struct ggml_tensor * r = ggml_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr));
-    struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_key, xk)));
-    return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k));
+    return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k));
 }
 
 struct llm_build_context {
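
For orientation, a single-token, eager-math reading of the llm_build_channel_mix graph built above. This is a sketch under assumptions that are not in the commit (row-major weight layouts, one token at a time); the real code builds ggml graph nodes over the whole batch.

#include <cmath>
#include <cstddef>
#include <vector>

// y = W * x for a row-major [rows x cols] matrix; stands in for ggml_mul_mat here.
static std::vector<float> matvec(const std::vector<float> & W, const std::vector<float> & x,
                                 std::size_t rows, std::size_t cols) {
    std::vector<float> y(rows, 0.0f);
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols; ++c) {
            y[r] += W[r * cols + c] * x[c];
        }
    }
    return y;
}

// Element-wise form of the graph above:
//   xk  = cur + (x_prev - cur) * lerp_k        (token-shift interpolation)
//   xr  = cur + (x_prev - cur) * lerp_r
//   k   = relu(W_key * xk)^2
//   out = sigmoid(W_recept * xr) * (W_value * k)
static std::vector<float> channel_mix_ref(
        const std::vector<float> & cur,      const std::vector<float> & x_prev,
        const std::vector<float> & lerp_k,   const std::vector<float> & lerp_r,
        const std::vector<float> & W_key,    // [n_ff x n_embd]    (assumed layout)
        const std::vector<float> & W_recept, // [n_embd x n_embd]  (assumed layout)
        const std::vector<float> & W_value,  // [n_embd x n_ff]    (assumed layout)
        std::size_t n_embd, std::size_t n_ff) {
    std::vector<float> xk(n_embd), xr(n_embd);
    for (std::size_t i = 0; i < n_embd; ++i) {
        const float sx = x_prev[i] - cur[i];
        xk[i] = cur[i] + sx * lerp_k[i];
        xr[i] = cur[i] + sx * lerp_r[i];
    }
    std::vector<float> k = matvec(W_key, xk, n_ff, n_embd);
    for (float & e : k) {
        e = e > 0.0f ? e * e : 0.0f;   // squared ReLU
    }
    std::vector<float> r = matvec(W_recept, xr, n_embd, n_embd);
    std::vector<float> v = matvec(W_value, k, n_embd, n_ff);
    std::vector<float> out(n_embd);
    for (std::size_t i = 0; i < n_embd; ++i) {
        out[i] = (1.0f / (1.0f + std::exp(-r[i]))) * v[i];   // sigmoid receptance gate
    }
    return out;
}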
@@ -14165,13 +14145,12 @@ struct llm_build_context {
         // Token shift state dimensions should be 2 * n_emb
         GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
 
-        // Input embeddings, start of the model after tokenizing ({n_embd, n_tokens})
         ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         struct ggml_tensor * state_mask = build_inp_s_mask();
         struct ggml_tensor * state_seq = build_inp_s_seq();
 
-        ggml_tensor * x = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+        ggml_tensor * cur = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
 
         for (int layer_i = 0; layer_i < n_layer; ++layer_i) {
             const llama_layer * layer = &model.layers[layer_i];
@@ -14200,16 +14179,16 @@ struct llm_build_context {
             struct ggml_tensor * att_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, 0);
             struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, n_embd * n_kv * ggml_element_size(kv_self.k_l[layer_i]));
 
-            struct ggml_tensor * x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
+            struct ggml_tensor * x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
             struct ggml_tensor * tmp = ggml_rwkv_token_shift(ctx0, att_shift, x_norm, state_seq);
             struct ggml_tensor * x_prev = ggml_reshape_2d(
                 ctx0,
                 ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
                 n_embd, n_tokens
             );
 
-            x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
-            ggml_build_forward_expand(gf, x);
+            cur = ggml_add(ctx0, cur, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
+            ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
                 ggml_cpy(
@@ -14237,15 +14216,15 @@ struct llm_build_context {
                 )
             );
 
-            x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
+            x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
             tmp = ggml_rwkv_token_shift(ctx0, ffn_shift, x_norm, state_seq);
             x_prev = ggml_reshape_2d(
                 ctx0,
                 ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
                 n_embd, n_tokens
             );
-            x = ggml_add(ctx0, x, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
-            ggml_build_forward_expand(gf, x);
+            cur = ggml_add(ctx0, cur, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
+            ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
                 ggml_cpy(
@@ -14279,21 +14258,18 @@ struct llm_build_context {
             );
 
             if ((layer_i + 1) % hparams.rescale_every_n_layers == 0) {
-                x = ggml_scale(ctx0, x, 0.5F);
+                cur = ggml_scale(ctx0, cur, 0.5F);
             }
         }
 
-        // Something related to skipping tokens, specifics unclear
         ggml_tensor * inp_out_ids = build_inp_out_ids();
-        x = ggml_get_rows(ctx0, x, inp_out_ids);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
 
-        // Output head, convert result vector to logits
-        x = llm_build_norm(ctx0, x, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
-        x = ggml_mul_mat(ctx0, model.output, x);
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
+        cur = ggml_mul_mat(ctx0, model.output, cur);
 
-        // Mark the output as being the result
-        cb(x, "result_output", -1);
-        ggml_build_forward_expand(gf, x);
+        cb(cur, "result_output", -1);
+        ggml_build_forward_expand(gf, cur);
 
         return gf;
     }
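
The last hunk leaves the output head logically unchanged apart from the rename: rows selected by inp_out_ids are normalized and projected against model.output to produce vocabulary logits. A per-row sketch of that projection, with a row-major [n_vocab x n_embd] weight layout assumed purely for illustration:

#include <cstddef>

// logits[v] = sum over e of output_w[v][e] * cur_norm[e], for one selected token
static void output_head_row(const float * output_w,   // [n_vocab x n_embd], row-major (assumed)
                            const float * cur_norm,   // [n_embd], already layer-normed
                            float * logits,           // [n_vocab]
                            std::size_t n_vocab, std::size_t n_embd) {
    for (std::size_t v = 0; v < n_vocab; ++v) {
        float acc = 0.0f;
        for (std::size_t e = 0; e < n_embd; ++e) {
            acc += output_w[v * n_embd + e] * cur_norm[e];
        }
        logits[v] = acc;
    }
}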
