
Commit bc8b9df

feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams
Branch: GraniteFour
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Parent: 3754cbd

File tree

src/llama-hparams.cpp
src/llama-hparams.h

2 files changed: +20 -4 lines

src/llama-hparams.cpp

Lines changed: 12 additions & 2 deletions
@@ -49,7 +49,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -60,7 +63,10 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -70,6 +76,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::recurrent_layer(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
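
With the layer index threaded through, the recurrent state sizes become per-layer quantities: in a hybrid model, attention layers simply report zero. The sketch below is illustrative only. It uses a simplified stand-in struct with made-up dimensions (not the real llama_hparams or the actual llama.cpp cache code) to show how a cache-sizing loop could sum the per-layer K/V recurrent state sizes.

```cpp
// Illustrative sketch only: a simplified stand-in for llama_hparams showing how
// per-layer n_embd_k_s / n_embd_v_s values could be accumulated for a hybrid model.
// The field values and the layer pattern below are made up for the example.
#include <array>
#include <cstdint>
#include <cstdio>

struct hparams_sketch {
    static constexpr uint32_t n_layer = 4;
    uint32_t ssm_d_conv  = 4;
    uint32_t ssm_d_inner = 1536;
    uint32_t ssm_d_state = 16;
    // layer 2 is an attention layer, the rest are recurrent (SSM) layers
    std::array<bool, n_layer> recurrent_layer_arr = {true, true, false, true};

    bool recurrent_layer(uint32_t il) const { return recurrent_layer_arr[il]; }

    // rolling (conv) state size; zero for non-recurrent layers
    uint32_t n_embd_k_s(uint32_t il) const {
        if (!recurrent_layer(il)) return 0;
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
    }

    // recurrent (SSM) state size; zero for non-recurrent layers
    uint32_t n_embd_v_s(uint32_t il) const {
        if (!recurrent_layer(il)) return 0;
        return ssm_d_state * ssm_d_inner;
    }
};

int main() {
    hparams_sketch hp;
    uint64_t total_k = 0, total_v = 0;
    for (uint32_t il = 0; il < hp.n_layer; ++il) {
        total_k += hp.n_embd_k_s(il);
        total_v += hp.n_embd_v_s(il);
    }
    std::printf("recurrent state elements: k = %llu, v = %llu\n",
                (unsigned long long) total_k, (unsigned long long) total_v);
    return 0;
}
```

Only the Mamba-style branch is reproduced here; the RWKV branch guarded by wkv_head_size follows the same per-layer gating in the diff above.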

src/llama-hparams.h

Lines changed: 8 additions & 2 deletions
@@ -112,6 +112,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv = 0.0f;
@@ -158,10 +161,13 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_k_s(uint32_t il = 0) const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_v_s(uint32_t il = 0) const;
+
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool recurrent_layer(uint32_t il) const;
 
     bool is_swa(uint32_t il) const;
 };
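
The new recurrent_layer_arr field is what the per-layer accessor reads. As a hypothetical illustration of how a hybrid model loader might populate it (the every-fourth-layer attention pattern and the names below are assumptions for this sketch, not taken from this commit):

```cpp
// Illustrative sketch only: one hypothetical way a hybrid loader could fill a
// per-layer recurrent flag array, with an attention layer every fourth layer and
// recurrent (SSM) layers everywhere else. The period and sizes are made up.
#include <array>
#include <cstdint>
#include <cstdio>

constexpr uint32_t LLAMA_MAX_LAYERS_SKETCH = 512; // stand-in for LLAMA_MAX_LAYERS

int main() {
    const uint32_t n_layer     = 8; // example layer count
    const uint32_t attn_period = 4; // hypothetical: every 4th layer uses attention

    std::array<bool, LLAMA_MAX_LAYERS_SKETCH> recurrent_layer_arr{};
    for (uint32_t il = 0; il < n_layer; ++il) {
        // last layer of each period is attention, the others are recurrent
        recurrent_layer_arr[il] = (il % attn_period) != (attn_period - 1);
    }

    for (uint32_t il = 0; il < n_layer; ++il) {
        std::printf("layer %2u: %s\n", il,
                    recurrent_layer_arr[il] ? "recurrent" : "attention");
    }
    return 0;
}
```

Because n_embd_k_s and n_embd_v_s take il with a default of 0, existing call sites that pass no layer index continue to compile unchanged.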
