@@ -298,6 +298,8 @@ enum llm_kv {
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
     LLM_KV_RESCALE_EVERY_N_LAYERS,
+    LLM_KV_TIME_MIX_EXTRA_DIM,
+    LLM_KV_TIME_DECAY_EXTRA_DIM,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -400,6 +402,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
     { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
+    { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" },
+    { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
 
     { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -2296,6 +2300,8 @@ struct llama_hparams {
 
     // for RWKV
     uint32_t rescale_every_n_layers = 0;
+    uint32_t time_mix_extra_dim = 0;
+    uint32_t time_decay_extra_dim = 0;
     uint32_t wkv_head_size = 0;
 
     float rope_attn_factor = 1.0f;
@@ -2362,6 +2368,8 @@ struct llama_hparams {
         if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
 
         if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
+        if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true;
+        if (this->time_decay_extra_dim != other.time_decay_extra_dim) return true;
         if (this->wkv_head_size != other.wkv_head_size) return true;
 
         if (this->dec_start_token_id != other.dec_start_token_id) return true;
@@ -5909,6 +5917,8 @@ static void llm_load_hparams(
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
+                ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
+                ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
                 ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
 
                 switch (hparams.n_layer) {
@@ -8364,8 +8374,8 @@ static bool llm_load_tensors(
                     model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
                     model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
 
-                    const int time_mix_extra_dim = (n_embd == 4096) ? 64 : 32;
-                    const int time_decay_extra_dim = (n_embd == 4096) ? 128 : 64;
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
                     const int head_size = hparams.wkv_head_size;
                     const int attn_hidden_size = n_embd;
                     const int ffn_size = hparams.n_ff_arr[0];
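
Note: the two new fields are stored as plain uint32 metadata in the GGUF file, with the key names taken from the LLM_KV_NAMES table above ("%s.time_mix_extra_dim" / "%s.time_decay_extra_dim", where "%s" is the architecture name). Below is a minimal standalone sketch, separate from this diff, that inspects them through the public GGUF C API; it assumes the architecture prefix is "rwkv6" and that the GGUF functions are declared in ggml.h as in this tree (they moved to a dedicated gguf.h in later ggml versions).

// inspect_rwkv_dims.cpp: print the RWKV time-mix / time-decay extra dims
// from a GGUF model file. Sketch only; key prefix "rwkv6" is an assumption.
#include <cstdio>
#include <cstdint>
#include "ggml.h" // GGUF API lived here at the time of this change

static uint32_t get_u32_or_zero(const struct gguf_context * ctx, const char * key) {
    const auto kid = gguf_find_key(ctx, key); // negative if the key is absent
    return kid < 0 ? 0 : gguf_get_val_u32(ctx, kid);
}

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }
    // only read metadata, do not allocate tensor data
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
    struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
    if (!ctx) {
        fprintf(stderr, "failed to load %s\n", argv[1]);
        return 1;
    }
    printf("time_mix_extra_dim   = %u\n", get_u32_or_zero(ctx, "rwkv6.time_mix_extra_dim"));
    printf("time_decay_extra_dim = %u\n", get_u32_or_zero(ctx, "rwkv6.time_decay_extra_dim"));
    gguf_free(ctx);
    return 0;
}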