From a6529ccd7662fc9e893e23134237789bf70f29ca Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 2 Apr 2025 14:22:33 -0600 Subject: [PATCH 01/11] feat: Add GGUF conversion for granitemoeshared Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 15 ++ gguf-py/gguf/constants.py | 291 ++++++++++++++++++--------------- gguf-py/gguf/tensor_mapping.py | 2 + 3 files changed, 172 insertions(+), 136 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bf6bc68380b19..3d48eb5810928 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5693,6 +5693,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("GraniteMoeSharedForCausalLM") +class GraniteMoeSharedModel(GraniteMoeModel): + """Conversion for IBM's GraniteMoeSharedForCausalLM""" + model_arch = gguf.MODEL_ARCH.GRANITE_MOE_SHARED + + def set_gguf_parameters(self): + """GraniteMoeShared uses GraniteMoe parameters plus the following: + - shared_intermediate_size + """ + super().set_gguf_parameters() + if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"): + self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length) + logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length) + + @ModelBase.register("BailingMoeForCausalLM") class BailingMoeModel(TextModel): model_arch = gguf.MODEL_ARCH.BAILINGMOE diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 7dd7bb6d1b5d9..7df57413e52d6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -255,74 +255,75 @@ class GGUFType: class MODEL_ARCH(IntEnum): - CLIP_VISION = auto() # dummy arch for clip.cpp - LLAMA = auto() - LLAMA4 = auto() - DECI = auto() - FALCON = auto() - BAICHUAN = auto() - GROK = auto() - GPT2 = auto() - GPTJ = auto() - GPTNEOX = auto() - MPT = auto() - STARCODER = auto() - REFACT = auto() - BERT = auto() - NOMIC_BERT = auto() - NOMIC_BERT_MOE = auto() - JINA_BERT_V2 = auto() - BLOOM = auto() - STABLELM = auto() - QWEN = auto() - QWEN2 = auto() - QWEN2MOE = auto() - QWEN2VL = auto() - QWEN3 = auto() - QWEN3MOE = auto() - PHI2 = auto() - PHI3 = auto() - PHIMOE = auto() - PLAMO = auto() - CODESHELL = auto() - ORION = auto() - INTERNLM2 = auto() - MINICPM = auto() - MINICPM3 = auto() - GEMMA = auto() - GEMMA2 = auto() - GEMMA3 = auto() - STARCODER2 = auto() - RWKV6 = auto() - RWKV6QWEN2 = auto() - RWKV7 = auto() - ARWKV7 = auto() - MAMBA = auto() - XVERSE = auto() - COMMAND_R = auto() - COHERE2 = auto() - DBRX = auto() - OLMO = auto() - OLMO2 = auto() - OLMOE = auto() - OPENELM = auto() - ARCTIC = auto() - DEEPSEEK = auto() - DEEPSEEK2 = auto() - CHATGLM = auto() - GLM4 = auto() - BITNET = auto() - T5 = auto() - T5ENCODER = auto() - JAIS = auto() - NEMOTRON = auto() - EXAONE = auto() - GRANITE = auto() - GRANITE_MOE = auto() - CHAMELEON = auto() - WAVTOKENIZER_DEC = auto() - PLM = auto() - BAILINGMOE = auto() + CLIP_VISION = auto() # dummy arch for clip.cpp + LLAMA = auto() + LLAMA4 = auto() + DECI = auto() + FALCON = auto() + BAICHUAN = auto() + GROK = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + MPT = auto() + STARCODER = auto() + REFACT = auto() + BERT = auto() + NOMIC_BERT = auto() + NOMIC_BERT_MOE = auto() + JINA_BERT_V2 = auto() + BLOOM = auto() + STABLELM = auto() + QWEN = auto() + QWEN2 = auto() + QWEN2MOE = auto() + QWEN2VL = auto() + QWEN3 = auto() + QWEN3MOE = auto() + 
PHI2 = auto() + PHI3 = auto() + PHIMOE = auto() + PLAMO = auto() + CODESHELL = auto() + ORION = auto() + INTERNLM2 = auto() + MINICPM = auto() + MINICPM3 = auto() + GEMMA = auto() + GEMMA2 = auto() + GEMMA3 = auto() + STARCODER2 = auto() + RWKV6 = auto() + RWKV6QWEN2 = auto() + RWKV7 = auto() + ARWKV7 = auto() + MAMBA = auto() + XVERSE = auto() + COMMAND_R = auto() + COHERE2 = auto() + DBRX = auto() + OLMO = auto() + OLMO2 = auto() + OLMOE = auto() + OPENELM = auto() + ARCTIC = auto() + DEEPSEEK = auto() + DEEPSEEK2 = auto() + CHATGLM = auto() + GLM4 = auto() + BITNET = auto() + T5 = auto() + T5ENCODER = auto() + JAIS = auto() + NEMOTRON = auto() + EXAONE = auto() + GRANITE = auto() + GRANITE_MOE = auto() + GRANITE_MOE_SHARED = auto() + CHAMELEON = auto() + WAVTOKENIZER_DEC = auto() + PLM = auto() + BAILINGMOE = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -512,74 +513,75 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { - MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp - MODEL_ARCH.LLAMA: "llama", - MODEL_ARCH.LLAMA4: "llama4", - MODEL_ARCH.DECI: "deci", - MODEL_ARCH.FALCON: "falcon", - MODEL_ARCH.BAICHUAN: "baichuan", - MODEL_ARCH.GROK: "grok", - MODEL_ARCH.GPT2: "gpt2", - MODEL_ARCH.GPTJ: "gptj", - MODEL_ARCH.GPTNEOX: "gptneox", - MODEL_ARCH.MPT: "mpt", - MODEL_ARCH.STARCODER: "starcoder", - MODEL_ARCH.REFACT: "refact", - MODEL_ARCH.BERT: "bert", - MODEL_ARCH.NOMIC_BERT: "nomic-bert", - MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", - MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", - MODEL_ARCH.BLOOM: "bloom", - MODEL_ARCH.STABLELM: "stablelm", - MODEL_ARCH.QWEN: "qwen", - MODEL_ARCH.QWEN2: "qwen2", - MODEL_ARCH.QWEN2MOE: "qwen2moe", - MODEL_ARCH.QWEN2VL: "qwen2vl", - MODEL_ARCH.QWEN3: "qwen3", - MODEL_ARCH.QWEN3MOE: "qwen3moe", - MODEL_ARCH.PHI2: "phi2", - MODEL_ARCH.PHI3: "phi3", - MODEL_ARCH.PHIMOE: "phimoe", - MODEL_ARCH.PLAMO: "plamo", - MODEL_ARCH.CODESHELL: "codeshell", - MODEL_ARCH.ORION: "orion", - MODEL_ARCH.INTERNLM2: "internlm2", - MODEL_ARCH.MINICPM: "minicpm", - MODEL_ARCH.MINICPM3: "minicpm3", - MODEL_ARCH.GEMMA: "gemma", - MODEL_ARCH.GEMMA2: "gemma2", - MODEL_ARCH.GEMMA3: "gemma3", - MODEL_ARCH.STARCODER2: "starcoder2", - MODEL_ARCH.RWKV6: "rwkv6", - MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", - MODEL_ARCH.RWKV7: "rwkv7", - MODEL_ARCH.ARWKV7: "arwkv7", - MODEL_ARCH.MAMBA: "mamba", - MODEL_ARCH.XVERSE: "xverse", - MODEL_ARCH.COMMAND_R: "command-r", - MODEL_ARCH.COHERE2: "cohere2", - MODEL_ARCH.DBRX: "dbrx", - MODEL_ARCH.OLMO: "olmo", - MODEL_ARCH.OLMO2: "olmo2", - MODEL_ARCH.OLMOE: "olmoe", - MODEL_ARCH.OPENELM: "openelm", - MODEL_ARCH.ARCTIC: "arctic", - MODEL_ARCH.DEEPSEEK: "deepseek", - MODEL_ARCH.DEEPSEEK2: "deepseek2", - MODEL_ARCH.CHATGLM: "chatglm", - MODEL_ARCH.GLM4: "glm4", - MODEL_ARCH.BITNET: "bitnet", - MODEL_ARCH.T5: "t5", - MODEL_ARCH.T5ENCODER: "t5encoder", - MODEL_ARCH.JAIS: "jais", - MODEL_ARCH.NEMOTRON: "nemotron", - MODEL_ARCH.EXAONE: "exaone", - MODEL_ARCH.GRANITE: "granite", - MODEL_ARCH.GRANITE_MOE: "granitemoe", - MODEL_ARCH.CHAMELEON: "chameleon", - MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", - MODEL_ARCH.PLM: "plm", - MODEL_ARCH.BAILINGMOE: "bailingmoe", + MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp + MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.LLAMA4: "llama4", + MODEL_ARCH.DECI: "deci", + MODEL_ARCH.FALCON: "falcon", + MODEL_ARCH.BAICHUAN: "baichuan", + MODEL_ARCH.GROK: "grok", + MODEL_ARCH.GPT2: "gpt2", + MODEL_ARCH.GPTJ: "gptj", + MODEL_ARCH.GPTNEOX: "gptneox", + MODEL_ARCH.MPT: "mpt", + 
MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.REFACT: "refact", + MODEL_ARCH.BERT: "bert", + MODEL_ARCH.NOMIC_BERT: "nomic-bert", + MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", + MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", + MODEL_ARCH.BLOOM: "bloom", + MODEL_ARCH.STABLELM: "stablelm", + MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.QWEN2: "qwen2", + MODEL_ARCH.QWEN2MOE: "qwen2moe", + MODEL_ARCH.QWEN2VL: "qwen2vl", + MODEL_ARCH.QWEN3: "qwen3", + MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PHI3: "phi3", + MODEL_ARCH.PHIMOE: "phimoe", + MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.CODESHELL: "codeshell", + MODEL_ARCH.ORION: "orion", + MODEL_ARCH.INTERNLM2: "internlm2", + MODEL_ARCH.MINICPM: "minicpm", + MODEL_ARCH.MINICPM3: "minicpm3", + MODEL_ARCH.GEMMA: "gemma", + MODEL_ARCH.GEMMA2: "gemma2", + MODEL_ARCH.GEMMA3: "gemma3", + MODEL_ARCH.STARCODER2: "starcoder2", + MODEL_ARCH.RWKV6: "rwkv6", + MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", + MODEL_ARCH.RWKV7: "rwkv7", + MODEL_ARCH.ARWKV7: "arwkv7", + MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.XVERSE: "xverse", + MODEL_ARCH.COMMAND_R: "command-r", + MODEL_ARCH.COHERE2: "cohere2", + MODEL_ARCH.DBRX: "dbrx", + MODEL_ARCH.OLMO: "olmo", + MODEL_ARCH.OLMO2: "olmo2", + MODEL_ARCH.OLMOE: "olmoe", + MODEL_ARCH.OPENELM: "openelm", + MODEL_ARCH.ARCTIC: "arctic", + MODEL_ARCH.DEEPSEEK: "deepseek", + MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.CHATGLM: "chatglm", + MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.T5: "t5", + MODEL_ARCH.T5ENCODER: "t5encoder", + MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.EXAONE: "exaone", + MODEL_ARCH.GRANITE: "granite", + MODEL_ARCH.GRANITE_MOE: "granitemoe", + MODEL_ARCH.GRANITE_MOE_SHARED: "granitemoeshared", + MODEL_ARCH.CHAMELEON: "chameleon", + MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + MODEL_ARCH.PLM: "plm", + MODEL_ARCH.BAILINGMOE: "bailingmoe", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -1894,6 +1896,23 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.GRANITE_MOE_SHARED: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + ], MODEL_ARCH.CHAMELEON: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 003b0172c77b0..b4e6634c497b4 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -346,6 +346,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 "language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 + "model.layers.{bid}.shared_mlp.input_linear", # granitemoeshared ), # AWQ-activation gate @@ -428,6 +429,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 + "model.layers.{bid}.shared_mlp.output_linear", # granitemoeshared ), MODEL_TENSOR.ATTN_Q_NORM: ( From 731c5fc44d842f2b93eef19ef12bd76823db6a14 Mon Sep 17 
00:00:00 2001 From: Gabe Goodhart Date: Wed, 2 Apr 2025 14:28:24 -0600 Subject: [PATCH 02/11] feat: hparam and arch plumbing for granitemoeshared Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart --- src/llama-arch.cpp | 157 +++++++++++++++++++++++++-------------------- src/llama-arch.h | 1 + 2 files changed, 90 insertions(+), 68 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index f2bc8ca768502..bff49ddbdd68b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -5,74 +5,75 @@ #include static const std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_LLAMA4, "llama4" }, - { LLM_ARCH_DECI, "deci" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GROK, "grok" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, - { LLM_ARCH_BAICHUAN, "baichuan" }, - { LLM_ARCH_STARCODER, "starcoder" }, - { LLM_ARCH_REFACT, "refact" }, - { LLM_ARCH_BERT, "bert" }, - { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, - { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, - { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, - { LLM_ARCH_BLOOM, "bloom" }, - { LLM_ARCH_STABLELM, "stablelm" }, - { LLM_ARCH_QWEN, "qwen" }, - { LLM_ARCH_QWEN2, "qwen2" }, - { LLM_ARCH_QWEN2MOE, "qwen2moe" }, - { LLM_ARCH_QWEN2VL, "qwen2vl" }, - { LLM_ARCH_QWEN3, "qwen3" }, - { LLM_ARCH_QWEN3MOE, "qwen3moe" }, - { LLM_ARCH_PHI2, "phi2" }, - { LLM_ARCH_PHI3, "phi3" }, - { LLM_ARCH_PHIMOE, "phimoe" }, - { LLM_ARCH_PLAMO, "plamo" }, - { LLM_ARCH_CODESHELL, "codeshell" }, - { LLM_ARCH_ORION, "orion" }, - { LLM_ARCH_INTERNLM2, "internlm2" }, - { LLM_ARCH_MINICPM, "minicpm" }, - { LLM_ARCH_MINICPM3, "minicpm3" }, - { LLM_ARCH_GEMMA, "gemma" }, - { LLM_ARCH_GEMMA2, "gemma2" }, - { LLM_ARCH_GEMMA3, "gemma3" }, - { LLM_ARCH_STARCODER2, "starcoder2" }, - { LLM_ARCH_MAMBA, "mamba" }, - { LLM_ARCH_XVERSE, "xverse" }, - { LLM_ARCH_COMMAND_R, "command-r" }, - { LLM_ARCH_COHERE2, "cohere2" }, - { LLM_ARCH_DBRX, "dbrx" }, - { LLM_ARCH_OLMO, "olmo" }, - { LLM_ARCH_OLMO2, "olmo2" }, - { LLM_ARCH_OLMOE, "olmoe" }, - { LLM_ARCH_OPENELM, "openelm" }, - { LLM_ARCH_ARCTIC, "arctic" }, - { LLM_ARCH_DEEPSEEK, "deepseek" }, - { LLM_ARCH_DEEPSEEK2, "deepseek2" }, - { LLM_ARCH_CHATGLM, "chatglm" }, - { LLM_ARCH_GLM4, "glm4" }, - { LLM_ARCH_BITNET, "bitnet" }, - { LLM_ARCH_T5, "t5" }, - { LLM_ARCH_T5ENCODER, "t5encoder" }, - { LLM_ARCH_JAIS, "jais" }, - { LLM_ARCH_NEMOTRON, "nemotron" }, - { LLM_ARCH_EXAONE, "exaone" }, - { LLM_ARCH_RWKV6, "rwkv6" }, - { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, - { LLM_ARCH_RWKV7, "rwkv7" }, - { LLM_ARCH_ARWKV7, "arwkv7" }, - { LLM_ARCH_GRANITE, "granite" }, - { LLM_ARCH_GRANITE_MOE, "granitemoe" }, - { LLM_ARCH_CHAMELEON, "chameleon" }, - { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, - { LLM_ARCH_PLM, "plm" }, - { LLM_ARCH_BAILINGMOE, "bailingmoe" }, - { LLM_ARCH_UNKNOWN, "(unknown)" }, + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA4, "llama4" }, + { LLM_ARCH_DECI, "deci" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, + { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, + { 
LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN2VL, "qwen2vl" }, + { LLM_ARCH_QWEN3, "qwen3" }, + { LLM_ARCH_QWEN3MOE, "qwen3moe" }, + { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, + { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, + { LLM_ARCH_INTERNLM2, "internlm2" }, + { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_GEMMA3, "gemma3" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, + { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_COHERE2, "cohere2" }, + { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OLMO2, "olmo2" }, + { LLM_ARCH_OLMOE, "olmoe" }, + { LLM_ARCH_OPENELM, "openelm" }, + { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK, "deepseek" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_T5ENCODER, "t5encoder" }, + { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, + { LLM_ARCH_RWKV7, "rwkv7" }, + { LLM_ARCH_ARWKV7, "arwkv7" }, + { LLM_ARCH_GRANITE, "granite" }, + { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_GRANITE_MOE_SHARED, "granitemoeshared" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, + { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_PLM, "plm" }, + { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, }; static const std::map LLM_KV_NAMES = { @@ -1483,6 +1484,26 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_GRANITE_MOE_SHARED, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_CHAMELEON, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 41a023da3da6e..6462273244711 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -72,6 +72,7 @@ enum llm_arch { LLM_ARCH_ARWKV7, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, + LLM_ARCH_GRANITE_MOE_SHARED, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, From c5d897edf00b70a7033697bda99e7f80bc81b967 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 18 Apr 2025 13:18:55 -0600 Subject: [PATCH 03/11] fix: Split MoE fused tensors for shared experts in conversion Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 18 ++++++++++++++++++ gguf-py/gguf/constants.py | 1 + gguf-py/gguf/tensor_mapping.py | 1 - 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3d48eb5810928..a03937abcc43f 100755 --- 
a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5707,6 +5707,24 @@ def set_gguf_parameters(self): self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length) logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + """In modeling_granitemoeshared, the implementation of parallel experts + is used. This essentially merges w1 and w3 into a single tensor with 2x + the hidden size that is then split during forward. To keep compatibility + with existing shared expert support, we pull them apart here. + """ + + if name.endswith("shared_mlp.input_linear.weight"): + ffn_dim = self.hparams["shared_intermediate_size"] + assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" + gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :] + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up), + ] + + return super().modify_tensors(data_torch, name, bid) + @ModelBase.register("BailingMoeForCausalLM") class BailingMoeModel(TextModel): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 7df57413e52d6..de001fd15a28a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1910,6 +1910,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, ], diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index b4e6634c497b4..5082cadb388e1 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -346,7 +346,6 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 "language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 - "model.layers.{bid}.shared_mlp.input_linear", # granitemoeshared ), # AWQ-activation gate From 054059ea4b9a5d92a9ad79618e95bbe593c69b6f Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 18 Apr 2025 13:20:15 -0600 Subject: [PATCH 04/11] feat: First WIP cut at model arch in cpp The hparam and architecture plumbing should be correct, but the implementation of the shared experts seems to still be broken. 
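For reference, the shared expert added below is intended to be a standard SwiGLU FFN: build_ffn with LLM_FFN_SILU and LLM_FFN_PAR computes down(silu(gate(h)) * up(h)) on the normalized hidden state h, and the result is summed with the routed experts' output.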
Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart --- src/llama-arch.cpp | 1 + src/llama-model.cpp | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index bff49ddbdd68b..bf13378a98b46 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1500,6 +1500,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 21b12339a221b..1db0e11931794 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1372,6 +1372,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_GRANITE_MOE_SHARED: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); @@ -1385,6 +1386,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { // Add additional layer/vocab/etc checks here for other model sizes default: type = LLM_TYPE_UNKNOWN; } + + // For Granite MoE Shared + if (arch == LLM_ARCH_GRANITE_MOE_SHARED) { + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + } } break; case LLM_ARCH_CHAMELEON: { @@ -1716,6 +1722,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_GRANITE_MOE_SHARED: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -1768,6 +1775,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (arch == LLM_ARCH_GRANITE_MOE_SHARED) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } } } } break; @@ -4381,10 +4395,14 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } - if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) { + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE || + arch == LLM_ARCH_GRANITE_MOE_SHARED) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_BAILINGMOE) { @@ -4668,6 +4686,20 @@ struct llm_build_llama : public llm_graph_context { LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, 
"ffn_moe_out", il); + + // For Granite MoE Shared + if (model.arch == LLM_ARCH_GRANITE_MOE_SHARED) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, cur, ffn_shexp); + cb(cur, "ffn_out", il); + } } // For Granite architecture @@ -12919,6 +12951,7 @@ llm_graph_result_ptr llama_model::build_graph( case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_GRANITE_MOE_SHARED: { llm = std::make_unique(*this, params, gf); } break; @@ -13296,6 +13329,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_GRANITE_MOE_SHARED: case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: return LLAMA_ROPE_TYPE_NORM; From 5a98b485fcf6b64750ffcedaaaff143e18bc14f7 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 25 Apr 2025 16:24:04 -0600 Subject: [PATCH 05/11] fix: Cleaner (maybe more correct?) splitting for gate/up Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a03937abcc43f..ec038e6b9deb4 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5717,7 +5717,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith("shared_mlp.input_linear.weight"): ffn_dim = self.hparams["shared_intermediate_size"] assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" - gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :] + gate, up = data_torch.split(ffn_dim, dim=-2) return [ (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up), From 9763c9a21b7cf14ccbe33adcad60f6d2d50adbf2 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 25 Apr 2025 16:24:58 -0600 Subject: [PATCH 06/11] fix: Fix the input to the shared experts I had misread that the shared experts take the inputs _before_ the standard MoE layer and was feeding the output of the MoE to the shared experts. 
Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1db0e11931794..7b8e075952b81 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4674,7 +4674,7 @@ struct llm_build_llama : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - cur = build_moe_ffn(cur, + ggml_tensor * moe_out = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -4685,7 +4685,7 @@ struct llm_build_llama : public llm_graph_context { false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); - cb(cur, "ffn_moe_out", il); + cb(moe_out, "ffn_moe_out", il); // For Granite MoE Shared if (model.arch == LLM_ARCH_GRANITE_MOE_SHARED) { @@ -4697,8 +4697,10 @@ struct llm_build_llama : public llm_graph_context { LLM_FFN_SILU, LLM_FFN_PAR, il); cb(ffn_shexp, "ffn_shexp", il); - cur = ggml_add(ctx0, cur, ffn_shexp); + cur = ggml_add(ctx0, moe_out, ffn_shexp); cb(cur, "ffn_out", il); + } else { + cur = moe_out; } } From 52d2ed6e0e146c3cba9e0664205cb757d5038320 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 12:32:54 -0600 Subject: [PATCH 07/11] fix: Avoid architecture-specific checks for Granite MoE Shared This is a cleaner way that will allow more flexibility in architecture strings going forward. Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7b8e075952b81..4b53486ff3c6c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1388,9 +1388,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } // For Granite MoE Shared - if (arch == LLM_ARCH_GRANITE_MOE_SHARED) { - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); - } + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); } break; case LLM_ARCH_CHAMELEON: { @@ -1777,7 +1775,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); // For Granite MoE Shared - if (arch == LLM_ARCH_GRANITE_MOE_SHARED) { + if (hparams.n_ff_shexp > 0) { layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); @@ -4688,7 +4686,7 @@ struct llm_build_llama : public llm_graph_context { cb(moe_out, "ffn_moe_out", il); // For Granite MoE Shared - if (model.arch == LLM_ARCH_GRANITE_MOE_SHARED) { + if (hparams.n_ff_shexp > 0) { ggml_tensor * ffn_shexp = build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, From 4446994924eb382cbc447b1f0a67298faaa2e8e3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 12:52:55 -0600 Subject: [PATCH 08/11] refactor: Split granite architectures out of llm_build_llama This helps de-clutter the llama-family graph construction and allows granite to diverge further (in preparation for Granite 4). NOTE: I removed the granite scale factors from llm_build_deci because they appear to only be there as copy-paste from llm_build_llama. 
The HF config does not seem to set those values: https://huggingface.co/Deci/DeciLM-7B/blob/main/config.json Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 247 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 196 insertions(+), 51 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4b53486ff3c6c..ef8eb5a40ac07 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4610,11 +4610,6 @@ struct llm_build_llama : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -4672,7 +4667,7 @@ struct llm_build_llama : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - ggml_tensor * moe_out = build_moe_ffn(cur, + cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -4683,28 +4678,7 @@ struct llm_build_llama : public llm_graph_context { false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); - cb(moe_out, "ffn_moe_out", il); - - // For Granite MoE Shared - if (hparams.n_ff_shexp > 0) { - ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(ffn_shexp, "ffn_shexp", il); - - cur = ggml_add(ctx0, moe_out, ffn_shexp); - cb(cur, "ffn_out", il); - } else { - cur = moe_out; - } - } - - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -4729,11 +4703,6 @@ struct llm_build_llama : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // For Granite architecture - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - } - cb(cur, "result_output", -1); res->t_logits = cur; @@ -4844,11 +4813,6 @@ struct llm_build_deci : public llm_graph_context { continue; } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - // modified to support attention-free layer of Llama-3_1-Nemotron-51B ggml_tensor * ffn_inp = cur; if (n_head > 0) { @@ -4872,11 +4836,6 @@ struct llm_build_deci : public llm_graph_context { cb(cur, "ffn_out", il); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -4899,11 +4858,6 @@ struct llm_build_deci : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // For Granite architecture - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - } - cb(cur, "result_output", -1); res->t_logits = cur; @@ -12242,6 +12196,194 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { } }; + +struct llm_build_granite : public llm_graph_context { + llm_build_granite( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_rope = true) + : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + 
ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - used for rope if enabled + ggml_tensor * inp_pos; + if (use_rope) { + inp_pos = build_inp_pos(); + } + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (use_rope) { + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + 
ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                } else {
+                    cur = moe_out;
+                }
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architectures - scale logits
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 // ref: https://github.com/facebookresearch/chameleon
 // based on the original build_llama() function, changes:
 //   * qk-norm
@@ -12949,9 +13091,6 @@ llm_graph_result_ptr llama_model::build_graph(
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_LLAMA4:
         case LLM_ARCH_MINICPM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
-        case LLM_ARCH_GRANITE_MOE_SHARED:
             {
                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
             } break;
@@ -13182,6 +13321,12 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique(*this, params, gf);
             } break;
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_GRANITE_MOE_SHARED:
+            {
+                llm = std::make_unique<llm_build_granite>(*this, params, gf);
+            } break;
         case LLM_ARCH_CHAMELEON:
             {
                 llm = std::make_unique<llm_build_chameleon>(*this, params, gf);

From 3d792146e17e03ee4ed8c09ce40eb3d66b2801a6 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Fri, 9 May 2025 13:03:52 -0600
Subject: [PATCH 09/11] fix: Fix compiler warning about uninitialized inp_pos

This should not have been reachable, but it warns on some compilers

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart
---
 src/llama-model.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ef8eb5a40ac07..f33cc470eec9f 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -12215,11 +12215,8 @@ struct llm_build_granite : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        // inp_pos - used for rope if enabled
-        ggml_tensor * inp_pos;
-        if (use_rope) {
-            inp_pos = build_inp_pos();
-        }
+        // inp_pos - built only if rope enabled
+        ggml_tensor * inp_pos = nullptr;
 
         auto * inp_attn = build_attn_inp_kv_unified();
 
@@ -12262,6 +12259,10 @@ struct llm_build_granite : public llm_graph_context {
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 if (use_rope) {
+
+                    if (!inp_pos) {
+                        inp_pos = build_inp_pos();
+                    }
                     ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
                     Qcur = ggml_rope_ext(
                             ctx0, Qcur, inp_pos, rope_factors,

From 2aed91c3340f5a655539512a1dc653fab2dd8d79 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Mon, 12 May 2025 21:11:59 -0600
Subject: [PATCH 10/11] fix: Consolidate GraniteMoEShared into GraniteMoE for conversion

Branch: GraniteMoEShared

Signed-off-by: Gabe Goodhart
---
 convert_hf_to_gguf.py          |  37 ++---
 gguf-py/gguf/constants.py      | 289 ++++++++++++++++-----------------
 gguf-py/gguf/tensor_mapping.py |   2 +-
 3 files changed, 148 insertions(+), 180 deletions(-)

diff --git 
a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ec038e6b9deb4..8f9d21143d6b1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5669,11 +5669,20 @@ def set_gguf_parameters(self): logger.info("gguf: (granite) logits_scale = %s", logits_scale) -@ModelBase.register("GraniteMoeForCausalLM") +@ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM") class GraniteMoeModel(GraniteModel): """Conversion for IBM's GraniteMoeForCausalLM""" model_arch = gguf.MODEL_ARCH.GRANITE_MOE + def set_gguf_parameters(self): + """GraniteMoeShared uses GraniteMoe parameters plus the following: + - shared_intermediate_size + """ + super().set_gguf_parameters() + if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"): + self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length) + logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: """In modeling_granitemoe, the JetMoe implementation of parallel experts is used. This essentially merges w1 and w3 into a single tensor with 2x @@ -5684,36 +5693,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith("block_sparse_moe.input_linear.weight"): ffn_dim = self.hparams["intermediate_size"] assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size" - gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :] + gate, up = data_torch.split(ffn_dim, dim=-2) return [ (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate), (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up), ] - return super().modify_tensors(data_torch, name, bid) - - -@ModelBase.register("GraniteMoeSharedForCausalLM") -class GraniteMoeSharedModel(GraniteMoeModel): - """Conversion for IBM's GraniteMoeSharedForCausalLM""" - model_arch = gguf.MODEL_ARCH.GRANITE_MOE_SHARED - - def set_gguf_parameters(self): - """GraniteMoeShared uses GraniteMoe parameters plus the following: - - shared_intermediate_size - """ - super().set_gguf_parameters() - if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"): - self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length) - logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length) - - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - """In modeling_granitemoeshared, the implementation of parallel experts - is used. This essentially merges w1 and w3 into a single tensor with 2x - the hidden size that is then split during forward. To keep compatibility - with existing shared expert support, we pull them apart here. 
- """ - if name.endswith("shared_mlp.input_linear.weight"): ffn_dim = self.hparams["shared_intermediate_size"] assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index de001fd15a28a..7f0968bb4806e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -255,75 +255,74 @@ class GGUFType: class MODEL_ARCH(IntEnum): - CLIP_VISION = auto() # dummy arch for clip.cpp - LLAMA = auto() - LLAMA4 = auto() - DECI = auto() - FALCON = auto() - BAICHUAN = auto() - GROK = auto() - GPT2 = auto() - GPTJ = auto() - GPTNEOX = auto() - MPT = auto() - STARCODER = auto() - REFACT = auto() - BERT = auto() - NOMIC_BERT = auto() - NOMIC_BERT_MOE = auto() - JINA_BERT_V2 = auto() - BLOOM = auto() - STABLELM = auto() - QWEN = auto() - QWEN2 = auto() - QWEN2MOE = auto() - QWEN2VL = auto() - QWEN3 = auto() - QWEN3MOE = auto() - PHI2 = auto() - PHI3 = auto() - PHIMOE = auto() - PLAMO = auto() - CODESHELL = auto() - ORION = auto() - INTERNLM2 = auto() - MINICPM = auto() - MINICPM3 = auto() - GEMMA = auto() - GEMMA2 = auto() - GEMMA3 = auto() - STARCODER2 = auto() - RWKV6 = auto() - RWKV6QWEN2 = auto() - RWKV7 = auto() - ARWKV7 = auto() - MAMBA = auto() - XVERSE = auto() - COMMAND_R = auto() - COHERE2 = auto() - DBRX = auto() - OLMO = auto() - OLMO2 = auto() - OLMOE = auto() - OPENELM = auto() - ARCTIC = auto() - DEEPSEEK = auto() - DEEPSEEK2 = auto() - CHATGLM = auto() - GLM4 = auto() - BITNET = auto() - T5 = auto() - T5ENCODER = auto() - JAIS = auto() - NEMOTRON = auto() - EXAONE = auto() - GRANITE = auto() - GRANITE_MOE = auto() - GRANITE_MOE_SHARED = auto() - CHAMELEON = auto() - WAVTOKENIZER_DEC = auto() - PLM = auto() - BAILINGMOE = auto() + CLIP_VISION = auto() # dummy arch for clip.cpp + LLAMA = auto() + LLAMA4 = auto() + DECI = auto() + FALCON = auto() + BAICHUAN = auto() + GROK = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + MPT = auto() + STARCODER = auto() + REFACT = auto() + BERT = auto() + NOMIC_BERT = auto() + NOMIC_BERT_MOE = auto() + JINA_BERT_V2 = auto() + BLOOM = auto() + STABLELM = auto() + QWEN = auto() + QWEN2 = auto() + QWEN2MOE = auto() + QWEN2VL = auto() + QWEN3 = auto() + QWEN3MOE = auto() + PHI2 = auto() + PHI3 = auto() + PHIMOE = auto() + PLAMO = auto() + CODESHELL = auto() + ORION = auto() + INTERNLM2 = auto() + MINICPM = auto() + MINICPM3 = auto() + GEMMA = auto() + GEMMA2 = auto() + GEMMA3 = auto() + STARCODER2 = auto() + RWKV6 = auto() + RWKV6QWEN2 = auto() + RWKV7 = auto() + ARWKV7 = auto() + MAMBA = auto() + XVERSE = auto() + COMMAND_R = auto() + COHERE2 = auto() + DBRX = auto() + OLMO = auto() + OLMO2 = auto() + OLMOE = auto() + OPENELM = auto() + ARCTIC = auto() + DEEPSEEK = auto() + DEEPSEEK2 = auto() + CHATGLM = auto() + GLM4 = auto() + BITNET = auto() + T5 = auto() + T5ENCODER = auto() + JAIS = auto() + NEMOTRON = auto() + EXAONE = auto() + GRANITE = auto() + GRANITE_MOE = auto() + CHAMELEON = auto() + WAVTOKENIZER_DEC = auto() + PLM = auto() + BAILINGMOE = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -513,75 +512,74 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { - MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp - MODEL_ARCH.LLAMA: "llama", - MODEL_ARCH.LLAMA4: "llama4", - MODEL_ARCH.DECI: "deci", - MODEL_ARCH.FALCON: "falcon", - MODEL_ARCH.BAICHUAN: "baichuan", - MODEL_ARCH.GROK: "grok", - MODEL_ARCH.GPT2: "gpt2", - MODEL_ARCH.GPTJ: "gptj", - MODEL_ARCH.GPTNEOX: "gptneox", - MODEL_ARCH.MPT: 
"mpt", - MODEL_ARCH.STARCODER: "starcoder", - MODEL_ARCH.REFACT: "refact", - MODEL_ARCH.BERT: "bert", - MODEL_ARCH.NOMIC_BERT: "nomic-bert", - MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", - MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", - MODEL_ARCH.BLOOM: "bloom", - MODEL_ARCH.STABLELM: "stablelm", - MODEL_ARCH.QWEN: "qwen", - MODEL_ARCH.QWEN2: "qwen2", - MODEL_ARCH.QWEN2MOE: "qwen2moe", - MODEL_ARCH.QWEN2VL: "qwen2vl", - MODEL_ARCH.QWEN3: "qwen3", - MODEL_ARCH.QWEN3MOE: "qwen3moe", - MODEL_ARCH.PHI2: "phi2", - MODEL_ARCH.PHI3: "phi3", - MODEL_ARCH.PHIMOE: "phimoe", - MODEL_ARCH.PLAMO: "plamo", - MODEL_ARCH.CODESHELL: "codeshell", - MODEL_ARCH.ORION: "orion", - MODEL_ARCH.INTERNLM2: "internlm2", - MODEL_ARCH.MINICPM: "minicpm", - MODEL_ARCH.MINICPM3: "minicpm3", - MODEL_ARCH.GEMMA: "gemma", - MODEL_ARCH.GEMMA2: "gemma2", - MODEL_ARCH.GEMMA3: "gemma3", - MODEL_ARCH.STARCODER2: "starcoder2", - MODEL_ARCH.RWKV6: "rwkv6", - MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", - MODEL_ARCH.RWKV7: "rwkv7", - MODEL_ARCH.ARWKV7: "arwkv7", - MODEL_ARCH.MAMBA: "mamba", - MODEL_ARCH.XVERSE: "xverse", - MODEL_ARCH.COMMAND_R: "command-r", - MODEL_ARCH.COHERE2: "cohere2", - MODEL_ARCH.DBRX: "dbrx", - MODEL_ARCH.OLMO: "olmo", - MODEL_ARCH.OLMO2: "olmo2", - MODEL_ARCH.OLMOE: "olmoe", - MODEL_ARCH.OPENELM: "openelm", - MODEL_ARCH.ARCTIC: "arctic", - MODEL_ARCH.DEEPSEEK: "deepseek", - MODEL_ARCH.DEEPSEEK2: "deepseek2", - MODEL_ARCH.CHATGLM: "chatglm", - MODEL_ARCH.GLM4: "glm4", - MODEL_ARCH.BITNET: "bitnet", - MODEL_ARCH.T5: "t5", - MODEL_ARCH.T5ENCODER: "t5encoder", - MODEL_ARCH.JAIS: "jais", - MODEL_ARCH.NEMOTRON: "nemotron", - MODEL_ARCH.EXAONE: "exaone", - MODEL_ARCH.GRANITE: "granite", - MODEL_ARCH.GRANITE_MOE: "granitemoe", - MODEL_ARCH.GRANITE_MOE_SHARED: "granitemoeshared", - MODEL_ARCH.CHAMELEON: "chameleon", - MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", - MODEL_ARCH.PLM: "plm", - MODEL_ARCH.BAILINGMOE: "bailingmoe", + MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp + MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.LLAMA4: "llama4", + MODEL_ARCH.DECI: "deci", + MODEL_ARCH.FALCON: "falcon", + MODEL_ARCH.BAICHUAN: "baichuan", + MODEL_ARCH.GROK: "grok", + MODEL_ARCH.GPT2: "gpt2", + MODEL_ARCH.GPTJ: "gptj", + MODEL_ARCH.GPTNEOX: "gptneox", + MODEL_ARCH.MPT: "mpt", + MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.REFACT: "refact", + MODEL_ARCH.BERT: "bert", + MODEL_ARCH.NOMIC_BERT: "nomic-bert", + MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", + MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", + MODEL_ARCH.BLOOM: "bloom", + MODEL_ARCH.STABLELM: "stablelm", + MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.QWEN2: "qwen2", + MODEL_ARCH.QWEN2MOE: "qwen2moe", + MODEL_ARCH.QWEN2VL: "qwen2vl", + MODEL_ARCH.QWEN3: "qwen3", + MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PHI3: "phi3", + MODEL_ARCH.PHIMOE: "phimoe", + MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.CODESHELL: "codeshell", + MODEL_ARCH.ORION: "orion", + MODEL_ARCH.INTERNLM2: "internlm2", + MODEL_ARCH.MINICPM: "minicpm", + MODEL_ARCH.MINICPM3: "minicpm3", + MODEL_ARCH.GEMMA: "gemma", + MODEL_ARCH.GEMMA2: "gemma2", + MODEL_ARCH.GEMMA3: "gemma3", + MODEL_ARCH.STARCODER2: "starcoder2", + MODEL_ARCH.RWKV6: "rwkv6", + MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", + MODEL_ARCH.RWKV7: "rwkv7", + MODEL_ARCH.ARWKV7: "arwkv7", + MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.XVERSE: "xverse", + MODEL_ARCH.COMMAND_R: "command-r", + MODEL_ARCH.COHERE2: "cohere2", + MODEL_ARCH.DBRX: "dbrx", + MODEL_ARCH.OLMO: "olmo", + MODEL_ARCH.OLMO2: "olmo2", + MODEL_ARCH.OLMOE: "olmoe", + 
MODEL_ARCH.OPENELM: "openelm", + MODEL_ARCH.ARCTIC: "arctic", + MODEL_ARCH.DEEPSEEK: "deepseek", + MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.CHATGLM: "chatglm", + MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.T5: "t5", + MODEL_ARCH.T5ENCODER: "t5encoder", + MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.EXAONE: "exaone", + MODEL_ARCH.GRANITE: "granite", + MODEL_ARCH.GRANITE_MOE: "granitemoe", + MODEL_ARCH.CHAMELEON: "chameleon", + MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + MODEL_ARCH.PLM: "plm", + MODEL_ARCH.BAILINGMOE: "bailingmoe", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -1895,21 +1893,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, - ], - MODEL_ARCH.GRANITE_MOE_SHARED: [ - MODEL_TENSOR.TOKEN_EMBD, - MODEL_TENSOR.OUTPUT_NORM, - MODEL_TENSOR.OUTPUT, - MODEL_TENSOR.ATTN_NORM, - MODEL_TENSOR.ATTN_Q, - MODEL_TENSOR.ATTN_K, - MODEL_TENSOR.ATTN_V, - MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.FFN_NORM, - MODEL_TENSOR.FFN_GATE_INP, - MODEL_TENSOR.FFN_GATE_EXP, - MODEL_TENSOR.FFN_DOWN_EXP, - MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 5082cadb388e1..83c5b27ae1199 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -428,7 +428,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 - "model.layers.{bid}.shared_mlp.output_linear", # granitemoeshared + "model.layers.{bid}.shared_mlp.output_linear", # granitemoe ), MODEL_TENSOR.ATTN_Q_NORM: ( From 33008e8c852e5ab5d56c108450fa8839dc1b6481 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 12 May 2025 21:12:42 -0600 Subject: [PATCH 11/11] fix: Consolidate GraniteMoEShared into GraniteMoE on the c++ side Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart --- src/llama-arch.cpp | 155 +++++++++++++++++++------------------------- src/llama-arch.h | 1 - src/llama-model.cpp | 7 +- 3 files changed, 69 insertions(+), 94 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index bf13378a98b46..abf436adac416 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -5,75 +5,74 @@ #include static const std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_LLAMA4, "llama4" }, - { LLM_ARCH_DECI, "deci" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GROK, "grok" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, - { LLM_ARCH_BAICHUAN, "baichuan" }, - { LLM_ARCH_STARCODER, "starcoder" }, - { LLM_ARCH_REFACT, "refact" }, - { LLM_ARCH_BERT, "bert" }, - { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, - { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, - { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, - { LLM_ARCH_BLOOM, "bloom" }, - { LLM_ARCH_STABLELM, "stablelm" }, - { LLM_ARCH_QWEN, "qwen" }, - { LLM_ARCH_QWEN2, "qwen2" }, - { LLM_ARCH_QWEN2MOE, "qwen2moe" }, - { LLM_ARCH_QWEN2VL, "qwen2vl" }, - { LLM_ARCH_QWEN3, "qwen3" }, - { LLM_ARCH_QWEN3MOE, "qwen3moe" }, - { LLM_ARCH_PHI2, "phi2" }, - { LLM_ARCH_PHI3, "phi3" }, - { LLM_ARCH_PHIMOE, "phimoe" }, - { LLM_ARCH_PLAMO, "plamo" }, - { LLM_ARCH_CODESHELL, "codeshell" }, - { LLM_ARCH_ORION, "orion" }, - { 
LLM_ARCH_INTERNLM2, "internlm2" }, - { LLM_ARCH_MINICPM, "minicpm" }, - { LLM_ARCH_MINICPM3, "minicpm3" }, - { LLM_ARCH_GEMMA, "gemma" }, - { LLM_ARCH_GEMMA2, "gemma2" }, - { LLM_ARCH_GEMMA3, "gemma3" }, - { LLM_ARCH_STARCODER2, "starcoder2" }, - { LLM_ARCH_MAMBA, "mamba" }, - { LLM_ARCH_XVERSE, "xverse" }, - { LLM_ARCH_COMMAND_R, "command-r" }, - { LLM_ARCH_COHERE2, "cohere2" }, - { LLM_ARCH_DBRX, "dbrx" }, - { LLM_ARCH_OLMO, "olmo" }, - { LLM_ARCH_OLMO2, "olmo2" }, - { LLM_ARCH_OLMOE, "olmoe" }, - { LLM_ARCH_OPENELM, "openelm" }, - { LLM_ARCH_ARCTIC, "arctic" }, - { LLM_ARCH_DEEPSEEK, "deepseek" }, - { LLM_ARCH_DEEPSEEK2, "deepseek2" }, - { LLM_ARCH_CHATGLM, "chatglm" }, - { LLM_ARCH_GLM4, "glm4" }, - { LLM_ARCH_BITNET, "bitnet" }, - { LLM_ARCH_T5, "t5" }, - { LLM_ARCH_T5ENCODER, "t5encoder" }, - { LLM_ARCH_JAIS, "jais" }, - { LLM_ARCH_NEMOTRON, "nemotron" }, - { LLM_ARCH_EXAONE, "exaone" }, - { LLM_ARCH_RWKV6, "rwkv6" }, - { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, - { LLM_ARCH_RWKV7, "rwkv7" }, - { LLM_ARCH_ARWKV7, "arwkv7" }, - { LLM_ARCH_GRANITE, "granite" }, - { LLM_ARCH_GRANITE_MOE, "granitemoe" }, - { LLM_ARCH_GRANITE_MOE_SHARED, "granitemoeshared" }, - { LLM_ARCH_CHAMELEON, "chameleon" }, - { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, - { LLM_ARCH_PLM, "plm" }, - { LLM_ARCH_BAILINGMOE, "bailingmoe" }, - { LLM_ARCH_UNKNOWN, "(unknown)" }, + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA4, "llama4" }, + { LLM_ARCH_DECI, "deci" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, + { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN2VL, "qwen2vl" }, + { LLM_ARCH_QWEN3, "qwen3" }, + { LLM_ARCH_QWEN3MOE, "qwen3moe" }, + { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, + { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, + { LLM_ARCH_INTERNLM2, "internlm2" }, + { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_GEMMA3, "gemma3" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, + { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_COHERE2, "cohere2" }, + { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OLMO2, "olmo2" }, + { LLM_ARCH_OLMOE, "olmoe" }, + { LLM_ARCH_OPENELM, "openelm" }, + { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK, "deepseek" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_T5ENCODER, "t5encoder" }, + { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, + { LLM_ARCH_RWKV7, "rwkv7" }, + { LLM_ARCH_ARWKV7, "arwkv7" }, + { LLM_ARCH_GRANITE, "granite" }, + { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, + { 
LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_PLM, "plm" }, + { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, }; static const std::map LLM_KV_NAMES = { @@ -1468,24 +1467,6 @@ static const std::map> LLM_TENSOR_N }, { LLM_ARCH_GRANITE_MOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_GRANITE_MOE_SHARED, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 6462273244711..41a023da3da6e 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -72,7 +72,6 @@ enum llm_arch { LLM_ARCH_ARWKV7, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, - LLM_ARCH_GRANITE_MOE_SHARED, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f33cc470eec9f..65f80da9372fe 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1372,7 +1372,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_GRANITE_MOE_SHARED: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); @@ -1720,7 +1719,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_GRANITE_MOE_SHARED: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4395,8 +4393,7 @@ void llama_model::print_info() const { if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || - arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_MOE_SHARED) { + arch == LLM_ARCH_GRANITE_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); @@ -13324,7 +13321,6 @@ llm_graph_result_ptr llama_model::build_graph( } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_GRANITE_MOE_SHARED: { llm = std::make_unique(*this, params, gf); } break; @@ -13475,7 +13471,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_GRANITE_MOE_SHARED: case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: return LLAMA_ROPE_TYPE_NORM;
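
A minimal sketch of the fused gate/up split used throughout this series (introduced for the shared experts in PATCH 03, simplified to torch.split in PATCH 05, and applied to both routed and shared experts after the PATCH 10 consolidation). The tensor sizes below are illustrative assumptions, not values from any checkpoint; only the split semantics mirror the conversion code:

    import torch

    # The HF checkpoints fuse the w1 (gate) and w3 (up) projections along
    # dim -2, giving a weight of shape [2 * ffn_dim, n_embd].
    ffn_dim, n_embd = 512, 1024              # illustrative sizes
    fused = torch.randn(2 * ffn_dim, n_embd)

    # PATCH 03 sliced the halves explicitly; PATCH 05 swaps in torch.split,
    # which produces the same two [ffn_dim, n_embd] views:
    gate_slice, up_slice = fused[..., :ffn_dim, :], fused[..., ffn_dim:, :]
    gate, up = fused.split(ffn_dim, dim=-2)

    assert torch.equal(gate, gate_slice) and torch.equal(up, up_slice)
    assert gate.shape == up.shape == (ffn_dim, n_embd)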