
Commit 7f9de4a

Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
1 parent 0a630dc · commit 7f9de4a


3 files changed: +19 −19 lines changed


convert_hf_to_gguf.py

Lines changed: 1 addition & 1 deletion
@@ -2705,7 +2705,7 @@ class StarCoder2Model(Model):
 
 @Model.register("Rwkv6ForCausalLM")
 class RwkvModel(Model):
-    model_arch = gguf.MODEL_ARCH.RWKV
+    model_arch = gguf.MODEL_ARCH.RWKV6
 
     def set_vocab(self):
         assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
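The decorator in this hunk maps the Hugging Face `architectures` string "Rwkv6ForCausalLM" to the converter class, while `model_arch` now selects the RWKV6 GGUF architecture. Below is a minimal sketch of how such a decorator-based registry works; it is illustrative only, and `_registry` and the simplified lookup are hypothetical stand-ins for the converter's actual internals.

```python
# Minimal sketch of a decorator-registry pattern (illustrative only; the
# real Model class in convert_hf_to_gguf.py carries the full conversion
# logic, and `_registry` here is a hypothetical attribute name).
from typing import Callable, Dict, Type


class Model:
    _registry: Dict[str, Type["Model"]] = {}
    model_arch = None  # subclasses set this to a gguf.MODEL_ARCH member

    @classmethod
    def register(cls, *names: str) -> Callable[[Type["Model"]], Type["Model"]]:
        def wrapper(subclass: Type["Model"]) -> Type["Model"]:
            for name in names:
                cls._registry[name] = subclass  # HF arch string -> converter class
            return subclass
        return wrapper

    @classmethod
    def from_model_architecture(cls, name: str) -> Type["Model"]:
        return cls._registry[name]


@Model.register("Rwkv6ForCausalLM")
class RwkvModel(Model):
    model_arch = "rwkv6"  # stand-in for gguf.MODEL_ARCH.RWKV6


print(Model.from_model_architecture("Rwkv6ForCausalLM").model_arch)  # rwkv6
```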

gguf-py/gguf/constants.py

Lines changed: 3 additions & 3 deletions
@@ -210,7 +210,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA = auto()
     GEMMA2 = auto()
     STARCODER2 = auto()
-    RWKV = auto()
+    RWKV6 = auto()
     MAMBA = auto()
     XVERSE = auto()
     COMMAND_R = auto()
@@ -362,7 +362,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GEMMA: "gemma",
     MODEL_ARCH.GEMMA2: "gemma2",
     MODEL_ARCH.STARCODER2: "starcoder2",
-    MODEL_ARCH.RWKV: "rwkv",
+    MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.XVERSE: "xverse",
     MODEL_ARCH.COMMAND_R: "command-r",
@@ -903,7 +903,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
-    MODEL_ARCH.RWKV: [
+    MODEL_ARCH.RWKV6: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
         MODEL_TENSOR.OUTPUT_NORM,
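These three hunks are one rename applied in three coordinated places: the MODEL_ARCH enum member, the architecture-name string map, and the per-architecture tensor list. The snippet below is a self-contained sketch of that pattern with a made-up subset of members, not the real constants.py; it only illustrates why the pieces must change together, since the string ("rwkv6") is what gets written into the GGUF header and matched by llama.cpp on load.

```python
# Illustrative sketch of the enum + name-map pattern used in
# gguf-py/gguf/constants.py (the real enum has many more members).
from enum import IntEnum, auto


class MODEL_ARCH(IntEnum):
    STARCODER2 = auto()
    RWKV6 = auto()
    MAMBA = auto()


MODEL_ARCH_NAMES = {
    MODEL_ARCH.STARCODER2: "starcoder2",
    MODEL_ARCH.RWKV6: "rwkv6",
    MODEL_ARCH.MAMBA: "mamba",
}

# Every architecture should have a name entry; otherwise the converter
# could not write a "general.architecture" value for it.
assert all(arch in MODEL_ARCH_NAMES for arch in MODEL_ARCH)
```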

src/llama.cpp

Lines changed: 15 additions & 15 deletions
@@ -210,7 +210,7 @@ enum llm_arch {
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
     LLM_ARCH_JAIS,
-    LLM_ARCH_RWKV,
+    LLM_ARCH_RWKV6,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -256,7 +256,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_T5, "t5" },
     { LLM_ARCH_T5ENCODER, "t5encoder" },
     { LLM_ARCH_JAIS, "jais" },
-    { LLM_ARCH_RWKV, "rwkv" },
+    { LLM_ARCH_RWKV6, "rwkv6" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -1328,7 +1328,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
         },
     },
     {
-        LLM_ARCH_RWKV,
+        LLM_ARCH_RWKV6,
         {
             { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
             { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
@@ -3052,7 +3052,7 @@ static bool llama_kv_cache_init(
     cache.has_shift = false;
 
     // TODO: find a nicer way to add other recurrent model architectures
-    cache.recurrent = model.arch == LLM_ARCH_MAMBA || model.arch == LLM_ARCH_RWKV;
+    cache.recurrent = model.arch == LLM_ARCH_MAMBA || model.arch == LLM_ARCH_RWKV6;
     cache.v_trans = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
@@ -5348,7 +5348,7 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
-        case LLM_ARCH_RWKV:
+        case LLM_ARCH_RWKV6:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
@@ -7700,7 +7700,7 @@ static bool llm_load_tensors(
                     layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
                 }
             } break;
-        case LLM_ARCH_RWKV:
+        case LLM_ARCH_RWKV6:
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
@@ -8555,7 +8555,7 @@ static struct ggml_tensor * llm_build_kv(
 }
 
 
-static struct ggml_tensor * llm_build_time_mix(
+static struct ggml_tensor * llm_build_time_mix_rwkv6(
         struct ggml_context * ctx,
         const struct llama_layer * layer,
         struct ggml_tensor * cur,
@@ -8716,7 +8716,7 @@ static struct ggml_tensor * llm_build_time_mix(
     return ggml_mul_mat(ctx, layer->time_mix_output, cur);
 }
 
-static struct ggml_tensor * llm_build_channel_mix(
+static struct ggml_tensor * llm_build_channel_mix_rwkv6(
         struct ggml_context * ctx,
         const struct llama_layer * layer,
         struct ggml_tensor * cur,
@@ -14134,7 +14134,7 @@ struct llm_build_context {
         return gf;
     }
 
-    ggml_cgraph * build_rwkv() {
+    ggml_cgraph * build_rwkv6() {
        ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
        // Token shift state dimensions should be 2 * n_emb
@@ -14182,7 +14182,7 @@ struct llm_build_context {
                 n_embd, n_tokens
             );
 
-            cur = ggml_add(ctx0, cur, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
+            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -14218,7 +14218,7 @@ struct llm_build_context {
                 ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
                 n_embd, n_tokens
             );
-            cur = ggml_add(ctx0, cur, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
+            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm, x_prev));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -14523,9 +14523,9 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_jais();
             } break;
-        case LLM_ARCH_RWKV:
+        case LLM_ARCH_RWKV6:
             {
-                result = llm.build_rwkv();
+                result = llm.build_rwkv6();
             } break;
         default:
             GGML_ABORT("fatal error");
@@ -17250,7 +17250,7 @@ struct llama_context * llama_new_context_with_model(
     ggml_type type_v = params.type_v;
 
     // Mamba and RWKV only need a constant number of KV cache cells per sequence
-    if (model->arch == LLM_ARCH_MAMBA || model->arch == LLM_ARCH_RWKV) {
+    if (model->arch == LLM_ARCH_MAMBA || model->arch == LLM_ARCH_RWKV6) {
         // Mamba and RWKV need at least as many KV cells as there are sequences kept at any time
         kv_size = std::max((uint32_t) 1, params.n_seq_max);
         // it's probably best to keep as much precision as possible for the states
@@ -17560,7 +17560,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5:
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
-        case LLM_ARCH_RWKV:
+        case LLM_ARCH_RWKV6:
             return LLAMA_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
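Two of the hunks above (in llama_kv_cache_init and llama_new_context_with_model) encode the same fact: RWKV6, like Mamba, is a recurrent architecture, so its cache holds one fixed-size state per sequence rather than one entry per token. The sketch below is an illustrative back-of-the-envelope comparison in Python, not llama.cpp code; the sizes are made up for the example.

```python
# Illustrative only: contrast cache sizing for an attention architecture
# versus a recurrent one such as RWKV6 (cf. cache.recurrent and
# kv_size = std::max((uint32_t) 1, params.n_seq_max) in the hunks above).
n_seq_max = 4     # parallel sequences kept alive (made-up number)
n_ctx = 4096      # context length (made-up number)

# Attention-style KV cache: one cell per token position in the context.
attn_cells = n_ctx

# Recurrent-style cache: one fixed-size state cell per sequence,
# overwritten in place at every step, independent of context length.
recurrent_cells = max(1, n_seq_max)

print(f"attention cells: {attn_cells}, recurrent cells: {recurrent_cells}")
```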

0 commit comments
