@@ -9384,6 +9384,7 @@ static struct ggml_tensor * llm_build_mamba(
9384
9384
}
9385
9385
9386
9386
static struct ggml_tensor * llm_build_time_mix_rwkv6(
9387
+ struct llama_context & lctx,
9387
9388
struct ggml_context * ctx,
9388
9389
const struct llama_layer * layer,
9389
9390
struct ggml_tensor * cur,
@@ -9481,12 +9482,12 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
9481
9482
cur
9482
9483
);
9483
9484
9484
- struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat( ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
9485
- struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat( ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
9486
- struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat( ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
9485
+ struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
9486
+ struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
9487
+ struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
9487
9488
struct ggml_tensor * g = ggml_silu(
9488
9489
ctx,
9489
- ggml_mul_mat( ctx, layer->time_mix_gate, xg)
9490
+ llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
9490
9491
);
9491
9492
9492
9493
struct ggml_tensor * w = ggml_mul_mat(
@@ -9516,12 +9517,13 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
9516
9517
cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
9517
9518
9518
9519
cur = ggml_mul(ctx, cur, g);
9519
- cur = ggml_mul_mat( ctx, layer->time_mix_output, cur);
9520
+ cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
9520
9521
9521
9522
return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs);
9522
9523
}
9523
9524
9524
9525
static struct ggml_tensor * llm_build_channel_mix_rwkv6(
9526
+ struct llama_context & lctx,
9525
9527
struct ggml_context * ctx,
9526
9528
const struct llama_layer * layer,
9527
9529
struct ggml_tensor * cur,
@@ -9530,15 +9532,15 @@ static struct ggml_tensor * llm_build_channel_mix_rwkv6(
9530
9532
struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
9531
9533
struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
9532
9534
9533
- struct ggml_tensor * r = ggml_sigmoid(ctx, ggml_mul_mat( ctx, layer->channel_mix_receptance, xr));
9535
+ struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr));
9534
9536
struct ggml_tensor * k = ggml_sqr(
9535
9537
ctx,
9536
9538
ggml_relu(
9537
9539
ctx,
9538
- ggml_mul_mat( ctx, layer->channel_mix_key, xk)
9540
+ llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
9539
9541
)
9540
9542
);
9541
- return ggml_mul(ctx, r, ggml_mul_mat( ctx, layer->channel_mix_value, k));
9543
+ return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
9542
9544
}
9543
9545
9544
9546
struct llm_build_context {
@@ -15109,7 +15111,7 @@ struct llm_build_context {
15109
15111
1
15110
15112
);
15111
15113
15112
- cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm_att, x_prev, &wkv_states));
15114
+ cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
15113
15115
ggml_build_forward_expand(gf, cur);
15114
15116
ggml_build_forward_expand(
15115
15117
gf,
@@ -15132,7 +15134,7 @@ struct llm_build_context {
15132
15134
ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
15133
15135
1
15134
15136
);
15135
- cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm_ffn, x_prev));
15137
+ cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(lctx, ctx0, layer, x_norm_ffn, x_prev));
15136
15138
ggml_build_forward_expand(gf, cur);
15137
15139
15138
15140
struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
@@ -15166,7 +15168,7 @@ struct llm_build_context {
15166
15168
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
15167
15169
15168
15170
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
15169
- cur = ggml_mul_mat( ctx0, model.output, cur);
15171
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
15170
15172
15171
15173
cb(cur, "result_output", -1);
15172
15174
ggml_build_forward_expand(gf, cur);
0 commit comments