Skip to content

Commit d8024a4

Browse files
committed
convert-hf : support new metadata keys for Mamba
For the models available at https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406
1 parent 7cd5a1f commit d8024a4

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

convert-hf-to-gguf.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,14 +1884,15 @@ def set_vocab(self):
18841884
self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
18851885

18861886
def set_gguf_parameters(self):
1887-
d_model = self.hparams["d_model"]
1888-
d_conv = self.hparams.get("d_conv", 4)
1889-
d_inner = self.hparams.get("d_inner", 2 * d_model)
1890-
d_state = self.hparams.get("d_state", 16)
1887+
d_model = self.find_hparam(["hidden_size", "d_model"])
1888+
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
1889+
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
1890+
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
18911891
# ceiling division
18921892
# ref: https://stackoverflow.com/a/17511341/22827863
18931893
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
1894-
dt_rank = self.hparams.get("dt_rank", -(d_model // -16))
1894+
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
1895+
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
18951896

18961897
# Fail early for models which don't have a block expansion factor of 2
18971898
assert d_inner == 2 * d_model
@@ -1906,7 +1907,7 @@ def set_gguf_parameters(self):
19061907
self.gguf_writer.add_ssm_inner_length(d_inner)
19071908
self.gguf_writer.add_ssm_state_length(d_state)
19081909
self.gguf_writer.add_ssm_dt_rank(dt_rank)
1909-
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
1910+
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
19101911
self.gguf_writer.add_file_type(self.ftype)
19111912

19121913
def write_tensors(self):

0 commit comments

Comments
 (0)