Skip to content

Commit 6d2f0b1

Browse files
committed
feat: Add conversion for Bamba models
This is borrowed and adapted from the original implementation ggml-org#10810 Branch: GraniteFour Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent f70df62 commit 6d2f0b1

File tree

4 files changed

+151
-6
lines changed

4 files changed

+151
-6
lines changed

convert_hf_to_gguf.py

Lines changed: 105 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4427,6 +4427,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
44274427
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
44284428
hparams = json.load(f)
44294429
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
4430+
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
4431+
self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
4432+
self.n_group = self.hparams.get("n_groups", 1)
44304433

44314434
def set_vocab(self):
44324435
vocab_size = self.hparams["vocab_size"]
@@ -4497,10 +4500,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
44974500
# (D is also unsqueezed, but for more straightforward broadcast internally)
44984501
data_torch = data_torch.reshape((*data_torch.shape, 1))
44994502
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
4500-
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
4501-
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
4502-
n_group = self.hparams.get("n_groups", 1)
4503-
data_torch = data_torch.reshape((n_group, d_inner // n_group))
4503+
data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))
45044504

45054505
if name.endswith(".A_log"):
45064506
logger.debug("A_log --> A ==> " + new_name)
@@ -4509,6 +4509,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
45094509
yield (new_name, data_torch)
45104510

45114511

4512+
@ModelBase.register("BambaForCausalLM")
4513+
class BambaModel(Mamba2Model):
4514+
"""Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
4515+
model_arch = gguf.MODEL_ARCH.BAMBA
4516+
undo_permute = True
4517+
4518+
def __init__(self, *args, **kwargs):
4519+
4520+
# Hybrid mamba models use a prefix for the mamba-specific params.
4521+
# TODO: Extend this if the prefix(es) need to be configurable
4522+
self.hparam_prefixes = ["mamba"]
4523+
4524+
super().__init__(*args, **kwargs)
4525+
4526+
# Use Llama conversion for attention
4527+
self._transformer_model_class: type[TextModel] = LlamaModel
4528+
4529+
# Lists of which layers use ssm vs attention
4530+
self._attn_layers = self.hparams.get("attn_layer_indices", [])
4531+
if not self._attn_layers:
4532+
attn_period = self.hparams.get("attn_layer_period")
4533+
assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
4534+
attn_offset = self.hparams.get("attn_layer_offset")
4535+
assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
4536+
self._attn_layers = [
4537+
i for i in range(self.block_count)
4538+
if i % attn_period == attn_offset
4539+
]
4540+
self._ssm_layers = [
4541+
i for i in range(self.block_count)
4542+
if i not in self._attn_layers
4543+
]
4544+
4545+
# n_group and d_inner are used during reshape_tensors for mamaba2
4546+
self.d_model = self.find_hparam(["hidden_size", "d_model"])
4547+
self.n_group = self.find_hparam(["n_groups"])
4548+
self.d_inner = self.find_hparam(["expand"]) * self.d_model
4549+
4550+
def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
4551+
prefixed = []
4552+
for pfx in self.hparam_prefixes:
4553+
prefixed.extend(
4554+
"_".join([pfx, k])
4555+
for k in keys
4556+
)
4557+
keys = list(keys) + prefixed
4558+
return super().find_hparam(keys, *args, **kwargs)
4559+
4560+
def set_gguf_parameters(self):
4561+
4562+
## General Params ##
4563+
self.gguf_writer.add_embedding_length(self.d_model)
4564+
self.gguf_writer.add_block_count(self.block_count)
4565+
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
4566+
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
4567+
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
4568+
4569+
## Mamba mixer params ##
4570+
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
4571+
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
4572+
self.gguf_writer.add_ssm_group_count(self.n_group)
4573+
self.gguf_writer.add_ssm_inner_size(self.d_inner)
4574+
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
4575+
# in llama.cpp
4576+
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
4577+
4578+
## Attention params ##
4579+
self.gguf_writer.add_attn_layer_indices(self._attn_layers)
4580+
self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"])
4581+
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
4582+
self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
4583+
4584+
## Feed Forward Params ##
4585+
self.gguf_writer.add_layer_norm_rms_eps(
4586+
self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
4587+
)
4588+
4589+
## Validation ##
4590+
d_head = self.find_hparam(["d_head"], optional=True) or 64
4591+
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
4592+
assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
4593+
4594+
def modify_tensors(
4595+
self, data_torch: Tensor, name: str, bid: int | None
4596+
) -> Iterable[tuple[str, Tensor]]:
4597+
4598+
# Determine whether this is a mamaba layer or an attention layer
4599+
if bid in self._ssm_layers:
4600+
for mamba_new_name, data_torch in super().modify_tensors(
4601+
data_torch, name, bid
4602+
):
4603+
yield mamba_new_name, data_torch
4604+
elif bid in self._attn_layers:
4605+
for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
4606+
self, data_torch, name, bid
4607+
):
4608+
yield llama_new_name, data_torch
4609+
else:
4610+
yield self.map_tensor_name(name), data_torch
4611+
4612+
45124613
@ModelBase.register("CohereForCausalLM")
45134614
class CommandR2Model(TextModel):
45144615
model_arch = gguf.MODEL_ARCH.COMMAND_R

gguf-py/gguf/constants.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ class SSM:
167167
GROUP_COUNT = "{arch}.ssm.group_count"
168168
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
169169

170+
class HybridAttention:
171+
ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
172+
170173
class WKV:
171174
HEAD_SIZE = "{arch}.wkv.head_size"
172175

@@ -300,6 +303,7 @@ class MODEL_ARCH(IntEnum):
300303
ARWKV7 = auto()
301304
MAMBA = auto()
302305
MAMBA2 = auto()
306+
BAMBA = auto()
303307
XVERSE = auto()
304308
COMMAND_R = auto()
305309
COHERE2 = auto()
@@ -563,6 +567,7 @@ class MODEL_TENSOR(IntEnum):
563567
MODEL_ARCH.ARWKV7: "arwkv7",
564568
MODEL_ARCH.MAMBA: "mamba",
565569
MODEL_ARCH.MAMBA2: "mamba2",
570+
MODEL_ARCH.BAMBA: "bamba",
566571
MODEL_ARCH.XVERSE: "xverse",
567572
MODEL_ARCH.COMMAND_R: "command-r",
568573
MODEL_ARCH.COHERE2: "cohere2",
@@ -1558,6 +1563,31 @@ class MODEL_TENSOR(IntEnum):
15581563
MODEL_TENSOR.SSM_NORM,
15591564
MODEL_TENSOR.SSM_OUT,
15601565
],
1566+
MODEL_ARCH.BAMBA: [
1567+
MODEL_TENSOR.TOKEN_EMBD,
1568+
MODEL_TENSOR.OUTPUT_NORM,
1569+
MODEL_TENSOR.OUTPUT,
1570+
MODEL_TENSOR.ATTN_NORM,
1571+
MODEL_TENSOR.SSM_IN,
1572+
MODEL_TENSOR.SSM_CONV1D,
1573+
MODEL_TENSOR.SSM_DT,
1574+
MODEL_TENSOR.SSM_A,
1575+
MODEL_TENSOR.SSM_D,
1576+
MODEL_TENSOR.SSM_NORM,
1577+
MODEL_TENSOR.SSM_OUT,
1578+
MODEL_TENSOR.ATTN_Q,
1579+
MODEL_TENSOR.ATTN_K,
1580+
MODEL_TENSOR.ATTN_V,
1581+
MODEL_TENSOR.ATTN_OUT,
1582+
MODEL_TENSOR.FFN_NORM,
1583+
MODEL_TENSOR.FFN_GATE,
1584+
MODEL_TENSOR.FFN_DOWN,
1585+
MODEL_TENSOR.FFN_UP,
1586+
MODEL_TENSOR.FFN_GATE_INP,
1587+
MODEL_TENSOR.FFN_GATE_EXP,
1588+
MODEL_TENSOR.FFN_DOWN_EXP,
1589+
MODEL_TENSOR.FFN_UP_EXP,
1590+
],
15611591
MODEL_ARCH.XVERSE: [
15621592
MODEL_TENSOR.TOKEN_EMBD,
15631593
MODEL_TENSOR.OUTPUT_NORM,

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,9 @@ def add_ssm_group_count(self, value: int) -> None:
848848
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
849849
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
850850

851+
def add_attn_layer_indices(self, values: list[int]) -> None:
852+
self.add_array(Keys.HybridAttention.ATTN_LAYER_INDICES.format(arch=self.arch), values)
853+
851854
def add_tokenizer_model(self, model: str) -> None:
852855
self.add_string(Keys.Tokenizer.MODEL, model)
853856

gguf-py/gguf/tensor_mapping.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TensorNameMap:
1313
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
1414
"transformer.word_embeddings", # falcon
1515
"word_embeddings", # bloom
16-
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
16+
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 bamba
1717
"tok_embeddings", # llama-pth
1818
"embeddings.word_embeddings", # bert nomic-bert
1919
"language_model.embedding.word_embeddings", # persimmon
@@ -117,7 +117,7 @@ class TensorNameMap:
117117
"transformer.h.{bid}.input_layernorm", # falcon7b
118118
"h.{bid}.input_layernorm", # bloom
119119
"transformer.h.{bid}.ln_mlp", # falcon40b
120-
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
120+
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe bamba
121121
"layers.{bid}.attention_norm", # llama-pth
122122
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
123123
"model.layers.{bid}.ln1", # yi
@@ -269,6 +269,7 @@ class TensorNameMap:
269269
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
270270
"transformer.layers.{bid}.ffn_norm", # openelm
271271
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4
272+
"model.layers.{bid}.pre_ff_layernorm", # bamba
272273
),
273274

274275
# Post feed-forward norm
@@ -330,6 +331,7 @@ class TensorNameMap:
330331
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
331332
"transformer.h.{bid}.mlp.c_fc_1", # exaone
332333
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
334+
"model.layers.{bid}.feed_forward.up_proj", # bamba
333335
),
334336

335337
MODEL_TENSOR.FFN_UP_EXP: (
@@ -367,6 +369,7 @@ class TensorNameMap:
367369
"model.layers.{bid}.residual_mlp.w1", # arctic
368370
"transformer.h.{bid}.mlp.c_fc_0", # exaone
369371
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
372+
"model.layers.{bid}.feed_forward.gate_proj", # bamba
370373
),
371374

372375
MODEL_TENSOR.FFN_GATE_EXP: (
@@ -411,6 +414,7 @@ class TensorNameMap:
411414
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
412415
"model.layers.h.{bid}.mlp.c_proj", # exaone
413416
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
417+
"model.layers.{bid}.feed_forward.down_proj", # bamba
414418
),
415419

416420
MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -464,11 +468,13 @@ class TensorNameMap:
464468
MODEL_TENSOR.SSM_IN: (
465469
"model.layers.{bid}.in_proj",
466470
"backbone.layers.{bid}.mixer.in_proj",
471+
"model.layers.{bid}.mamba.in_proj", # bamba
467472
),
468473

469474
MODEL_TENSOR.SSM_CONV1D: (
470475
"model.layers.{bid}.conv1d",
471476
"backbone.layers.{bid}.mixer.conv1d",
477+
"model.layers.{bid}.mamba.conv1d", # bamba
472478
),
473479

474480
MODEL_TENSOR.SSM_X: (
@@ -479,25 +485,30 @@ class TensorNameMap:
479485
MODEL_TENSOR.SSM_DT: (
480486
"model.layers.{bid}.dt_proj",
481487
"backbone.layers.{bid}.mixer.dt_proj",
488+
"model.layers.{bid}.mamba.dt_proj", # bamba
482489
),
483490

484491
MODEL_TENSOR.SSM_A: (
485492
"model.layers.{bid}.A_log",
486493
"backbone.layers.{bid}.mixer.A_log",
494+
"model.layers.{bid}.mamba.A_log", # bamba
487495
),
488496

489497
MODEL_TENSOR.SSM_D: (
490498
"model.layers.{bid}.D",
491499
"backbone.layers.{bid}.mixer.D",
500+
"model.layers.{bid}.mamba.D", # bamba
492501
),
493502

494503
MODEL_TENSOR.SSM_NORM: (
495504
"backbone.layers.{bid}.mixer.norm", # mamba2
505+
"model.layers.{bid}.mamba.norm", # bamba
496506
),
497507

498508
MODEL_TENSOR.SSM_OUT: (
499509
"model.layers.{bid}.out_proj",
500510
"backbone.layers.{bid}.mixer.out_proj",
511+
"model.layers.{bid}.mamba.out_proj", # bamba
501512
),
502513

503514
MODEL_TENSOR.TIME_MIX_W0: (

0 commit comments

Comments
 (0)