Commit 27d5dcf

mamba : stop abusing attention metadata
This breaks existing converted-to-GGUF Mamba models, but will allow supporting mixed architectures like MambaFormer without needing to break Mamba models. It will also allow changing the size of Mamba's states without having to reconvert models in the future (e.g. using something other than d_conv - 1 columns for the conv_states will not require breaking existing converted Mamba models again).

* gguf-py : add new KV metadata key-value pairs for Mamba
* llama : add new metadata key-value pairs for Mamba
* llama : guard against divisions by zero when n_head is 0
* mamba : rename "unlimited" KV cache property to "recurrent"
1 parent 49d865f commit 27d5dcf
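
The mention of d_conv - 1 columns refers to the conv_state layout this commit decouples from the metadata: since the oldest column is shifted out on every step, it never needs to persist between tokens. Below is a minimal, purely illustrative Python sketch of that rolling window (not code from this commit; names and values are made up):

```python
# Illustrative sketch only: a per-channel rolling convolution window of
# width d_conv. Because the oldest column is shifted out each step, only
# d_conv - 1 columns must be stored between steps.
d_conv = 4
weights = [0.1, 0.2, 0.3, 0.4]      # depthwise conv weights, oldest first
conv_state = [0.0] * (d_conv - 1)   # stored columns for one channel

def conv_step(state: list[float], x_new: float) -> tuple[list[float], float]:
    window = state + [x_new]        # full d_conv-wide window for this step
    y = sum(w * x for w, x in zip(weights, window))
    return window[1:], y            # drop the oldest column before storing

for x in [1.0, 2.0, 3.0]:
    conv_state, y = conv_step(conv_state, x)
```

The old metadata baked the d_conv - 1 storage choice into add_key_length; the new {arch}.ssm.d_conv key stores d_conv itself, so the stored-state layout can change later without another format break.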

File tree: 4 files changed (+128, -49 lines)


convert-hf-to-gguf.py

Lines changed: 12 additions & 5 deletions
```diff
@@ -1862,21 +1862,28 @@ def set_vocab(self):
 
     def set_gguf_parameters(self):
         d_model = self.hparams["d_model"]
+        d_conv = self.hparams.get("d_conv", 4)
         d_inner = self.hparams.get("d_inner", 2 * d_model)
+        d_state = self.hparams.get("d_state", 16)
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.hparams.get("dt_rank", -(d_model // -16))
+
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
 
         self.gguf_writer.add_name(self.dir_model.name)
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_head_count(d_inner) # the number of rows in conv_state and ssm_state
+        self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_ssm_conv_kernel_size(d_conv)
+        self.gguf_writer.add_ssm_inner_length(d_inner)
+        self.gguf_writer.add_ssm_state_length(d_state)
+        self.gguf_writer.add_ssm_dt_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
-        # NOTE: (ab)using the KV cache metadata to store dimensions for conv_state and ssm_state
-        # Since the first column of the conv_state is shifted out each time, it's not actually needed
-        self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4) - 1)
-        self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
         self.gguf_writer.add_file_type(self.ftype)
 
     def write_tensors(self):
```
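
The dt_rank default above uses Python's floor division to get a ceiling division without importing math, per the StackOverflow reference in the diff. A quick check of the identity (the d_model values are illustrative):

```python
import math

# -(a // -b) == ceil(a / b) for positive integers: floor division of a
# negated numerator rounds toward the ceiling.
d_model = 2560                                   # illustrative value
assert -(d_model // -16) == math.ceil(d_model / 16) == 160
assert -(50 // -16) == math.ceil(50 / 16) == 4   # non-multiple case
```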

gguf-py/gguf/constants.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -61,6 +61,12 @@ class Rope:
         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
 
+    class SSM:
+        CONV_KERNEL_SIZE = "{arch}.ssm.d_conv"
+        INNER_LENGTH = "{arch}.ssm.d_inner"
+        STATE_LENGTH = "{arch}.ssm.d_state"
+        DT_RANK = "{arch}.ssm.dt_rank"
+
     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
         LIST = "tokenizer.ggml.tokens"
@@ -747,6 +753,12 @@ def get_type(val: Any) -> GGUFValueType:
 KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
 KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED
 
+# SSM
+KEY_SSM_CONV_KERNEL_SIZE = Keys.SSM.CONV_KERNEL_SIZE
+KEY_SSM_INNER_LENGTH = Keys.SSM.INNER_LENGTH
+KEY_SSM_STATE_LENGTH = Keys.SSM.STATE_LENGTH
+KEY_SSM_DT_RANK = Keys.SSM.DT_RANK
+
 # tokenization
 KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
 KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST
```
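
These key constants are format-string templates; the {arch} placeholder is filled with the architecture name when the metadata is written. A small sketch of how they expand (assuming the "mamba" architecture name):

```python
from gguf.constants import Keys

# Assumes the architecture name "mamba"; any arch expands the same way.
assert Keys.SSM.CONV_KERNEL_SIZE.format(arch="mamba") == "mamba.ssm.d_conv"
assert Keys.SSM.INNER_LENGTH.format(arch="mamba") == "mamba.ssm.d_inner"
assert Keys.SSM.STATE_LENGTH.format(arch="mamba") == "mamba.ssm.d_state"
assert Keys.SSM.DT_RANK.format(arch="mamba") == "mamba.ssm.dt_rank"
```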

gguf-py/gguf/gguf_writer.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -382,6 +382,18 @@ def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
     def add_rope_scaling_finetuned(self, value: bool) -> None:
         self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
 
+    def add_ssm_conv_kernel_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.CONV_KERNEL_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_inner_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.INNER_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_state_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.STATE_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_dt_rank(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.DT_RANK.format(arch=self.arch), value)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
 
```
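With this commit applied, a conversion script can emit the new SSM metadata through these writer methods. A minimal usage sketch; the file name and hyperparameter values are illustrative (a hypothetical checkpoint with d_model = 768):

```python
import gguf

# Illustrative values only; the add_ssm_* methods exist after this commit.
writer = gguf.GGUFWriter("mamba-test.gguf", "mamba")
writer.add_ssm_conv_kernel_size(4)    # d_conv
writer.add_ssm_inner_length(1536)     # d_inner = 2 * d_model
writer.add_ssm_state_length(16)       # d_state
writer.add_ssm_dt_rank(48)            # ceil(768 / 16)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
```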