Commit 7e46bde

llama: rwkv6: Apply code style and misc changes
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
1 parent c1a3a9c commit 7e46bde

2 files changed (+37, -47 lines)

convert_hf_to_gguf.py

Lines changed: 2 additions & 0 deletions
@@ -297,6 +297,7 @@ def prepare_tensors(self):
                             gguf.MODEL_TENSOR.POS_EMBD,
                             gguf.MODEL_TENSOR.TOKEN_TYPES,
                             gguf.MODEL_TENSOR.SSM_CONV1D,
+                            gguf.MODEL_TENSOR.TIME_MIX_FIRST,
                         )
                     )
                     or not name.endswith(".weight")
@@ -2760,6 +2761,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
         self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
         self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_file_type(self.ftype)

         # required by llama.cpp, unused
         self.gguf_writer.add_head_count(0)
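
Judging from the surrounding condition (the tuple feeds a check that ends in "or not name.endswith(\".weight\")"), the first hunk adds TIME_MIX_FIRST to the set of tensor types that prepare_tensors always keeps in F32 rather than quantizing; the second hunk records the requested output file type in the GGUF metadata via add_file_type(self.ftype).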

src/llama.cpp

Lines changed: 35 additions & 47 deletions
@@ -5161,6 +5161,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_1B: return "1B";
         case MODEL_1_3B: return "1.3B";
         case MODEL_1_4B: return "1.4B";
+        case MODEL_1_6B: return "1.6B";
         case MODEL_2B: return "2B";
         case MODEL_2_8B: return "2.8B";
         case MODEL_3B: return "3B";
@@ -15064,49 +15065,40 @@ struct llm_build_context {
         GGML_ASSERT(batch.equal_seqs);
         GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);

-        ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
-
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
         struct ggml_tensor * state_copy = build_inp_s_copy();
         struct ggml_tensor * state_mask = build_inp_s_mask();

-        ggml_tensor * cur = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);

-        for (int layer_i = 0; layer_i < n_layer; ++layer_i) {
-            const llama_layer * layer = &model.layers[layer_i];
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];

             // (ab)using the KV cache to store the states
             struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.k_l[layer_i], state_copy, state_mask,
+                    gf, kv_self.k_l[il], state_copy, state_mask,
                     hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
             struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.v_l[layer_i], state_copy, state_mask,
+                    gf, kv_self.v_l[il], state_copy, state_mask,
                     hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);

-            cur = ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
+            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs);

-            token_shift = ggml_cont(
-                ctx0,
-                ggml_permute(
-                    ctx0,
-                    ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs),
-                    0, 2, 1, 3
-                )
-            );
+            struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
+            struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));

-            struct ggml_tensor * att_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, 0);
-            struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, n_embd * n_seqs * ggml_element_size(token_shift));
-            att_shift = ggml_reshape_3d(ctx0, att_shift, n_embd, 1, n_seqs);
-            ffn_shift = ggml_reshape_3d(ctx0, ffn_shift, n_embd, 1, n_seqs);
-
-            struct ggml_tensor * x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
+            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il);
             struct ggml_tensor * x_prev = ggml_concat(
                 ctx0,
                 att_shift,
-                ggml_view_3d(ctx0, x_norm, n_embd, n_seq_tokens - 1, n_seqs, x_norm->nb[1], x_norm->nb[2], 0),
+                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
                 1
             );

-            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm_att, x_prev, &wkv_states));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -15115,38 +15107,22 @@ struct llm_build_context {
                     wkv_states,
                     ggml_view_1d(
                         ctx0,
-                        kv_self.v_l[layer_i],
+                        kv_self.v_l[il],
                         hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_type_size(kv_self.v_l[layer_i]->type)
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
                     )
                 )
             );
-            struct ggml_tensor * last_norm = ggml_view_3d(ctx0, x_norm, n_embd, 1, n_seqs, x_norm->nb[1], x_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm));
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0, last_norm,
-                    ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, 0)
-                )
-            );

-            x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
+            ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il);
             x_prev = ggml_concat(
                 ctx0,
                 ffn_shift,
-                ggml_view_3d(ctx0, x_norm, n_embd, n_seq_tokens - 1, n_seqs, x_norm->nb[1], x_norm->nb[2], 0),
+                ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
                 1
             );
-            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm, x_prev));
-            last_norm = ggml_view_3d(ctx0, x_norm, n_embd, 1, n_seqs, x_norm->nb[1], x_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm));
+            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm_ffn, x_prev));
             ggml_build_forward_expand(gf, cur);
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0, last_norm,
-                    ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, n_embd * n_seqs * ggml_element_size(token_shift))
-                )
-            );

             token_shift = ggml_cont(
                 ctx0,
@@ -15157,20 +15133,32 @@ struct llm_build_context {
                 )
             );

+            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
+            struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn));
+
+            token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
+
             ggml_build_forward_expand(
                 gf,
                 ggml_cpy(
                     ctx0,
                     ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
-                    ggml_view_1d(ctx0, kv_self.k_l[layer_i], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_type_size(kv_self.k_l[layer_i]->type))
+                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
                 )
             );

-            if ((layer_i + 1) % hparams.rescale_every_n_layers == 0) {
+            if ((il + 1) % hparams.rescale_every_n_layers == 0) {
                 cur = ggml_scale(ctx0, cur, 0.5F);
             }
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }

+        cur = inpL;
         ggml_tensor * inp_out_ids = build_inp_out_ids();
         cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
         cur = ggml_get_rows(ctx0, cur, inp_out_ids);
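
For context on the reorganized graph above: each layer's two RWKV token-shift states (one for the time-mix branch, one for the channel-mix branch) now live in a single (n_embd, 2, n_seqs) tensor. They are read back out with plain ggml_view_3d slices instead of the earlier permute-and-copy sequence, and the last normed token of each branch is packed back together with a single ggml_concat before being copied into the KV cache. Below is a minimal standalone sketch of that layout in C. It is illustrative only, not code from this commit; it assumes a ggml build where ggml_concat takes a concatenation dimension (as used in the diff), and the sizes are arbitrary.

// token_shift_layout.c -- illustrative sketch, not part of the commit
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd = 8;   // arbitrary toy sizes
    const int64_t n_seqs = 2;

    // per-layer shift state as loaded from the KV cache: n_embd * 2 values per sequence
    struct ggml_tensor * token_shift = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 2, n_seqs);

    // row 0 of each sequence: time-mix (attention) shift
    struct ggml_tensor * att_shift = ggml_view_3d(ctx, token_shift,
            n_embd, 1, n_seqs,
            token_shift->nb[1], token_shift->nb[2], 0);

    // row 1 of each sequence: channel-mix (FFN) shift
    struct ggml_tensor * ffn_shift = ggml_view_3d(ctx, token_shift,
            n_embd, 1, n_seqs,
            token_shift->nb[1], token_shift->nb[2],
            n_embd * ggml_element_size(token_shift));

    // ... the layer would consume att_shift / ffn_shift here ...

    // on the way out, the last normed token of each branch is packed back into
    // one (n_embd, 2, n_seqs) tensor with a single concat along dim 1
    struct ggml_tensor * last_att = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1, n_seqs);
    struct ggml_tensor * last_ffn = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1, n_seqs);
    struct ggml_tensor * packed   = ggml_concat(ctx, last_att, last_ffn, 1);

    printf("att_shift: %lld x %lld x %lld\n",
           (long long) att_shift->ne[0], (long long) att_shift->ne[1], (long long) att_shift->ne[2]);
    printf("packed:    %lld x %lld x %lld\n",
           (long long) packed->ne[0], (long long) packed->ne[1], (long long) packed->ne[2]);

    ggml_free(ctx);
    return 0;
}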

0 commit comments
