Commit 7f2ef56

llama: rwkv6: Add lora for some supported tensors
Currently covers att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight.

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
1 parent: 7444046

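The change swaps plain ggml_mul_mat calls for llm_build_lora_mm in the RWKV6 graph builders, so any LoRA adapters loaded on the context are applied when these weights are multiplied. For orientation, here is a minimal sketch of the pattern such a helper implements: the base product W*x plus a low-rank correction scale * B*(A*x) per active adapter. The struct and function names below are illustrative only, not the actual bookkeeping in src/llama.cpp.

```cpp
#include <vector>

#include "ggml.h"

// One LoRA pair (A, B) attached to a base weight, plus its effective scale.
// Illustrative layout; llama.cpp keeps its own adapter bookkeeping on the context.
struct lora_weight_sketch {
    struct ggml_tensor * a;     // down-projection, maps n_embd -> rank
    struct ggml_tensor * b;     // up-projection,   maps rank   -> n_out
    float                scale; // typically user_scale * alpha / rank
};

// Compute W*x, then fold in scale * B*(A*x) for every adapter targeting W.
static struct ggml_tensor * lora_mm_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * w,
        struct ggml_tensor  * x,
        const std::vector<lora_weight_sketch> & adapters) {
    struct ggml_tensor * res = ggml_mul_mat(ctx, w, x);
    for (const auto & lw : adapters) {
        struct ggml_tensor * ax  = ggml_mul_mat(ctx, lw.a, x);  // project down to rank
        struct ggml_tensor * bax = ggml_mul_mat(ctx, lw.b, ax); // project back up
        res = ggml_add(ctx, res, ggml_scale(ctx, bax, lw.scale));
    }
    return res;
}
```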

src/llama.cpp (13 additions, 11 deletions)

@@ -9384,6 +9384,7 @@ static struct ggml_tensor * llm_build_mamba(
 }
 
 static struct ggml_tensor * llm_build_time_mix_rwkv6(
+        struct llama_context & lctx,
         struct ggml_context * ctx,
         const struct llama_layer * layer,
         struct ggml_tensor * cur,
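
Threading struct llama_context & lctx through both RWKV6 builders is what makes this work: the loaded adapters live on the context, so llm_build_lora_mm needs it to look up any LoRA weights attached to the tensor it is about to multiply.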

@@ -9481,12 +9482,12 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
         cur
     );
 
-    struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
-    struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
-    struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
+    struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
+    struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
+    struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
     struct ggml_tensor * g = ggml_silu(
         ctx,
-        ggml_mul_mat(ctx, layer->time_mix_gate, xg)
+        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
     );
 
     struct ggml_tensor * w = ggml_mul_mat(
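
With this hunk the four mixed projections pick up the adapter path. Writing lora_mm(W, x) = W*x + sum_i s_i * B_i*(A_i*x) for the adapter-aware product, the projections effectively become:

    r = reshape(lora_mm(W_receptance, xr))
    k = reshape(lora_mm(W_key, xk))
    v = reshape(lora_mm(W_value, xv))
    g = silu(lora_mm(W_gate, xg))

Note that the data-dependent decay w computed on the last context line keeps its plain ggml_mul_mat: consistent with the commit message, the time-mix decay tensors are not LoRA targets here.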

@@ -9516,12 +9517,13 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
     cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
 
     cur = ggml_mul(ctx, cur, g);
-    cur = ggml_mul_mat(ctx, layer->time_mix_output, cur);
+    cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
 
     return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs);
 }
 
 static struct ggml_tensor * llm_build_channel_mix_rwkv6(
+        struct llama_context & lctx,
         struct ggml_context * ctx,
         const struct llama_layer * layer,
         struct ggml_tensor * cur,

@@ -9530,15 +9532,15 @@ static struct ggml_tensor * llm_build_channel_mix_rwkv6(
     struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
     struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
 
-    struct ggml_tensor * r = ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr));
+    struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr));
     struct ggml_tensor * k = ggml_sqr(
         ctx,
         ggml_relu(
             ctx,
-            ggml_mul_mat(ctx, layer->channel_mix_key, xk)
+            llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
         )
     );
-    return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k));
+    return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
 }
 
 struct llm_build_context {
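
For reference, the channel-mix path now reads, with the same lora_mm shorthand:

    r   = sigmoid(lora_mm(W_ffn_receptance, xr))
    k   = relu(lora_mm(W_ffn_key, xk))^2
    out = r * lora_mm(W_ffn_value, k)

which covers all three ffn tensors named in the commit message (ffn.receptance/key/value).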

@@ -15109,7 +15111,7 @@ struct llm_build_context {
                 1
             );
 
-            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,

@@ -15132,7 +15134,7 @@ struct llm_build_context {
                 ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
                 1
             );
-            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm_ffn, x_prev));
+            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(lctx, ctx0, layer, x_norm_ffn, x_prev));
             ggml_build_forward_expand(gf, cur);
 
             struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));

@@ -15166,7 +15168,7 @@ struct llm_build_context {
         cur = ggml_get_rows(ctx0, cur, inp_out_ids);
 
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         cb(cur, "result_output", -1);
         ggml_build_forward_expand(gf, cur);
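
The last hunk also routes the output head (head.weight, loaded as model.output) through the LoRA path. As a usage sketch, an adapter would then be exercised roughly as below, assuming the llama_lora_adapter_* API llama.cpp exposed in this period; the adapter file name and error handling are illustrative:

```cpp
#include <cstdio>

#include "llama.h"

// Load a LoRA adapter and attach it to a context; once attached,
// llm_build_lora_mm folds it into every supported matmul (the att.*
// and ffn.* projections above, plus the output head) at graph build time.
static bool attach_rwkv6_adapter(struct llama_model * model, struct llama_context * ctx) {
    // "rwkv6-adapter.gguf" is a placeholder path, not part of this commit.
    struct llama_lora_adapter * adapter = llama_lora_adapter_init(model, "rwkv6-adapter.gguf");
    if (adapter == NULL) {
        fprintf(stderr, "failed to load LoRA adapter\n");
        return false;
    }
    // Scale 1.0f applies the adapter at full strength.
    return llama_lora_adapter_set(ctx, adapter, 1.0f) == 0;
}
```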
