Skip to content

Commit 587e739

Browse files
committed
feat: Add conversion for Bamba models
This is borrowed and adapted from the original implementation ggml-org#10810 Branch: GraniteFour Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent b1869d4 commit 587e739

File tree

4 files changed

+155
-10
lines changed

4 files changed

+155
-10
lines changed

convert_hf_to_gguf.py

Lines changed: 105 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4451,6 +4451,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
44514451
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
44524452
hparams = json.load(f)
44534453
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
4454+
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
4455+
self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
4456+
self.n_group = self.hparams.get("n_groups", 1)
44544457

44554458
def set_vocab(self):
44564459
vocab_size = self.hparams["vocab_size"]
@@ -4521,10 +4524,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
45214524
# (D is also unsqueezed, but for more straightforward broadcast internally)
45224525
data_torch = data_torch.reshape((*data_torch.shape, 1))
45234526
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
4524-
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
4525-
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
4526-
n_group = self.hparams.get("n_groups", 1)
4527-
data_torch = data_torch.reshape((n_group, d_inner // n_group))
4527+
data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))
45284528

45294529
if name.endswith(".A_log"):
45304530
logger.debug("A_log --> A ==> " + new_name)
@@ -4533,6 +4533,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
45334533
yield (new_name, data_torch)
45344534

45354535

4536+
@ModelBase.register("BambaForCausalLM")
4537+
class BambaModel(Mamba2Model):
4538+
"""Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
4539+
model_arch = gguf.MODEL_ARCH.BAMBA
4540+
undo_permute = True
4541+
4542+
def __init__(self, *args, **kwargs):
4543+
4544+
# Hybrid mamba models use a prefix for the mamba-specific params.
4545+
# TODO: Extend this if the prefix(es) need to be configurable
4546+
self.hparam_prefixes = ["mamba"]
4547+
4548+
super().__init__(*args, **kwargs)
4549+
4550+
# Use Llama conversion for attention
4551+
self._transformer_model_class: type[TextModel] = LlamaModel
4552+
4553+
# Lists of which layers use ssm vs attention
4554+
self._attn_layers = self.hparams.get("attn_layer_indices", [])
4555+
if not self._attn_layers:
4556+
attn_period = self.hparams.get("attn_layer_period")
4557+
assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
4558+
attn_offset = self.hparams.get("attn_layer_offset")
4559+
assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
4560+
self._attn_layers = [
4561+
i for i in range(self.block_count)
4562+
if i % attn_period == attn_offset
4563+
]
4564+
self._ssm_layers = [
4565+
i for i in range(self.block_count)
4566+
if i not in self._attn_layers
4567+
]
4568+
4569+
# n_group and d_inner are used during reshape_tensors for mamaba2
4570+
self.d_model = self.find_hparam(["hidden_size", "d_model"])
4571+
self.n_group = self.find_hparam(["n_groups"])
4572+
self.d_inner = self.find_hparam(["expand"]) * self.d_model
4573+
4574+
def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
4575+
prefixed = []
4576+
for pfx in self.hparam_prefixes:
4577+
prefixed.extend(
4578+
"_".join([pfx, k])
4579+
for k in keys
4580+
)
4581+
keys = list(keys) + prefixed
4582+
return super().find_hparam(keys, *args, **kwargs)
4583+
4584+
def set_gguf_parameters(self):
4585+
4586+
## General Params ##
4587+
self.gguf_writer.add_embedding_length(self.d_model)
4588+
self.gguf_writer.add_block_count(self.block_count)
4589+
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
4590+
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
4591+
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
4592+
4593+
## Mamba mixer params ##
4594+
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
4595+
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
4596+
self.gguf_writer.add_ssm_group_count(self.n_group)
4597+
self.gguf_writer.add_ssm_inner_size(self.d_inner)
4598+
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
4599+
# in llama.cpp
4600+
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
4601+
4602+
## Attention params ##
4603+
self.gguf_writer.add_attn_layer_indices(self._attn_layers)
4604+
self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"])
4605+
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
4606+
self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
4607+
4608+
## Feed Forward Params ##
4609+
self.gguf_writer.add_layer_norm_rms_eps(
4610+
self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
4611+
)
4612+
4613+
## Validation ##
4614+
d_head = self.find_hparam(["d_head"], optional=True) or 64
4615+
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
4616+
assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
4617+
4618+
def modify_tensors(
4619+
self, data_torch: Tensor, name: str, bid: int | None
4620+
) -> Iterable[tuple[str, Tensor]]:
4621+
4622+
# Determine whether this is a mamaba layer or an attention layer
4623+
if bid in self._ssm_layers:
4624+
for mamba_new_name, data_torch in super().modify_tensors(
4625+
data_torch, name, bid
4626+
):
4627+
yield mamba_new_name, data_torch
4628+
elif bid in self._attn_layers:
4629+
for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
4630+
self, data_torch, name, bid
4631+
):
4632+
yield llama_new_name, data_torch
4633+
else:
4634+
yield self.map_tensor_name(name), data_torch
4635+
4636+
45364637
@ModelBase.register("CohereForCausalLM")
45374638
class CommandR2Model(TextModel):
45384639
model_arch = gguf.MODEL_ARCH.COMMAND_R

gguf-py/gguf/constants.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ class SSM:
167167
GROUP_COUNT = "{arch}.ssm.group_count"
168168
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
169169

170+
class HybridAttention:
171+
ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
172+
170173
class WKV:
171174
HEAD_SIZE = "{arch}.wkv.head_size"
172175

@@ -300,6 +303,7 @@ class MODEL_ARCH(IntEnum):
300303
ARWKV7 = auto()
301304
MAMBA = auto()
302305
MAMBA2 = auto()
306+
BAMBA = auto()
303307
XVERSE = auto()
304308
COMMAND_R = auto()
305309
COHERE2 = auto()
@@ -564,6 +568,7 @@ class MODEL_TENSOR(IntEnum):
564568
MODEL_ARCH.ARWKV7: "arwkv7",
565569
MODEL_ARCH.MAMBA: "mamba",
566570
MODEL_ARCH.MAMBA2: "mamba2",
571+
MODEL_ARCH.BAMBA: "bamba",
567572
MODEL_ARCH.XVERSE: "xverse",
568573
MODEL_ARCH.COMMAND_R: "command-r",
569574
MODEL_ARCH.COHERE2: "cohere2",
@@ -1561,6 +1566,31 @@ class MODEL_TENSOR(IntEnum):
15611566
MODEL_TENSOR.SSM_NORM,
15621567
MODEL_TENSOR.SSM_OUT,
15631568
],
1569+
MODEL_ARCH.BAMBA: [
1570+
MODEL_TENSOR.TOKEN_EMBD,
1571+
MODEL_TENSOR.OUTPUT_NORM,
1572+
MODEL_TENSOR.OUTPUT,
1573+
MODEL_TENSOR.ATTN_NORM,
1574+
MODEL_TENSOR.SSM_IN,
1575+
MODEL_TENSOR.SSM_CONV1D,
1576+
MODEL_TENSOR.SSM_DT,
1577+
MODEL_TENSOR.SSM_A,
1578+
MODEL_TENSOR.SSM_D,
1579+
MODEL_TENSOR.SSM_NORM,
1580+
MODEL_TENSOR.SSM_OUT,
1581+
MODEL_TENSOR.ATTN_Q,
1582+
MODEL_TENSOR.ATTN_K,
1583+
MODEL_TENSOR.ATTN_V,
1584+
MODEL_TENSOR.ATTN_OUT,
1585+
MODEL_TENSOR.FFN_NORM,
1586+
MODEL_TENSOR.FFN_GATE,
1587+
MODEL_TENSOR.FFN_DOWN,
1588+
MODEL_TENSOR.FFN_UP,
1589+
MODEL_TENSOR.FFN_GATE_INP,
1590+
MODEL_TENSOR.FFN_GATE_EXP,
1591+
MODEL_TENSOR.FFN_DOWN_EXP,
1592+
MODEL_TENSOR.FFN_UP_EXP,
1593+
],
15641594
MODEL_ARCH.XVERSE: [
15651595
MODEL_TENSOR.TOKEN_EMBD,
15661596
MODEL_TENSOR.OUTPUT_NORM,

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,9 @@ def add_ssm_group_count(self, value: int) -> None:
848848
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
849849
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
850850

851+
def add_attn_layer_indices(self, values: list[int]) -> None:
852+
self.add_array(Keys.HybridAttention.ATTN_LAYER_INDICES.format(arch=self.arch), values)
853+
851854
def add_tokenizer_model(self, model: str) -> None:
852855
self.add_string(Keys.Tokenizer.MODEL, model)
853856

gguf-py/gguf/tensor_mapping.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TensorNameMap:
1313
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
1414
"transformer.word_embeddings", # falcon
1515
"word_embeddings", # bloom
16-
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
16+
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 bamba
1717
"tok_embeddings", # llama-pth
1818
"embeddings.word_embeddings", # bert nomic-bert
1919
"language_model.embedding.word_embeddings", # persimmon
@@ -117,7 +117,7 @@ class TensorNameMap:
117117
"transformer.h.{bid}.input_layernorm", # falcon7b
118118
"h.{bid}.input_layernorm", # bloom
119119
"transformer.h.{bid}.ln_mlp", # falcon40b
120-
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
120+
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe bamba
121121
"layers.{bid}.attention_norm", # llama-pth
122122
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
123123
"model.layers.{bid}.ln1", # yi
@@ -268,7 +268,8 @@ class TensorNameMap:
268268
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
269269
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
270270
"transformer.layers.{bid}.ffn_norm", # openelm
271-
"model.layers.{bid}.post_attention_layernorm", # llama4
271+
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4
272+
"model.layers.{bid}.pre_ff_layernorm", # bamba
272273
),
273274

274275
# Post feed-forward norm
@@ -329,7 +330,8 @@ class TensorNameMap:
329330
"model.layers.{bid}.residual_mlp.w3", # arctic
330331
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
331332
"transformer.h.{bid}.mlp.c_fc_1", # exaone
332-
"model.layers.{bid}.feed_forward.up_proj", # llama4
333+
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
334+
"model.layers.{bid}.feed_forward.up_proj", # bamba
333335
),
334336

335337
MODEL_TENSOR.FFN_UP_EXP: (
@@ -366,7 +368,8 @@ class TensorNameMap:
366368
"transformer.h.{bid}.mlp.linear_1", # refact
367369
"model.layers.{bid}.residual_mlp.w1", # arctic
368370
"transformer.h.{bid}.mlp.c_fc_0", # exaone
369-
"model.layers.{bid}.feed_forward.gate_proj", # llama4
371+
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
372+
"model.layers.{bid}.feed_forward.gate_proj", # bamba
370373
),
371374

372375
MODEL_TENSOR.FFN_GATE_EXP: (
@@ -410,7 +413,8 @@ class TensorNameMap:
410413
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
411414
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
412415
"model.layers.h.{bid}.mlp.c_proj", # exaone
413-
"model.layers.{bid}.feed_forward.down_proj", # llama4
416+
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
417+
"model.layers.{bid}.feed_forward.down_proj", # bamba
414418
),
415419

416420
MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -464,11 +468,13 @@ class TensorNameMap:
464468
MODEL_TENSOR.SSM_IN: (
465469
"model.layers.{bid}.in_proj",
466470
"backbone.layers.{bid}.mixer.in_proj",
471+
"model.layers.{bid}.mamba.in_proj", # bamba
467472
),
468473

469474
MODEL_TENSOR.SSM_CONV1D: (
470475
"model.layers.{bid}.conv1d",
471476
"backbone.layers.{bid}.mixer.conv1d",
477+
"model.layers.{bid}.mamba.conv1d", # bamba
472478
),
473479

474480
MODEL_TENSOR.SSM_X: (
@@ -479,25 +485,30 @@ class TensorNameMap:
479485
MODEL_TENSOR.SSM_DT: (
480486
"model.layers.{bid}.dt_proj",
481487
"backbone.layers.{bid}.mixer.dt_proj",
488+
"model.layers.{bid}.mamba.dt_proj", # bamba
482489
),
483490

484491
MODEL_TENSOR.SSM_A: (
485492
"model.layers.{bid}.A_log",
486493
"backbone.layers.{bid}.mixer.A_log",
494+
"model.layers.{bid}.mamba.A_log", # bamba
487495
),
488496

489497
MODEL_TENSOR.SSM_D: (
490498
"model.layers.{bid}.D",
491499
"backbone.layers.{bid}.mixer.D",
500+
"model.layers.{bid}.mamba.D", # bamba
492501
),
493502

494503
MODEL_TENSOR.SSM_NORM: (
495504
"backbone.layers.{bid}.mixer.norm", # mamba2
505+
"model.layers.{bid}.mamba.norm", # bamba
496506
),
497507

498508
MODEL_TENSOR.SSM_OUT: (
499509
"model.layers.{bid}.out_proj",
500510
"backbone.layers.{bid}.mixer.out_proj",
511+
"model.layers.{bid}.mamba.out_proj", # bamba
501512
),
502513

503514
MODEL_TENSOR.TIME_MIX_W0: (

0 commit comments

Comments
 (0)