
Commit 99db667

slaren authored and committed, with co-authors Georgi Gerganov (ggerganov) and Radek Pilar (Mrkvak)
llama : add Mixtral support (ggml-org#4406)
* convert : support Mixtral as LLAMA arch
* convert : fix n_ff typo
* llama : model loading
* ggml : sync latest ggml_mul_mat_id
* llama : update graph to support MoE
* llama : fix cur -> cur_expert
* llama : first working version
* llama : fix expert weighting in the FFN
* ggml : ggml_get_rows support 2D indexing [n_tokens, n_experts] (cpu only)
* ggml : add n_as argument to ggml_mul_mat_id
* ggml : fix ggml_get_rows to take into account ne02 / ne11
* metal : add more general support for ggml_get_rows + tests
* llama : add basic support for offloading moe with CUDA
* metal : add/mul/div use general kernel when src1 not cont
* metal : reduce the kernel launches for ggml_mul_mat_id
* ggml : get_rows : support non-contiguous tensors with gaps, generalize up to 3D
* ggml : update get_rows f16 and q
* cuda : support non-contiguous src1 in get_rows
* llama : offload missing ffn_moe_silu
* metal : fix ggml_get_rows to work with non-cont src1
* metal : add indirect mat-vec kernels for all quantization types
* llama : do not quantize expert gating tensors
* llama : add n_expert and n_expert_used to hparams + change quants
* test-backend-ops : add moe test
* cuda : fix get_rows when ncols is odd
* convert : determine n_ctx correctly
* metal : fix ggml_mul_mat_id for F32
* test-backend-ops : make experts more evenly probable (test_moe)
* test-backend-ops : cleanup, add moe test for batches
* test-backend-ops : add cpy from f32 -> all types test
* test-backend-ops : fix dequantize block offset
* llama : fix hard-coded number of experts
* test-backend-ops : simplify and disable slow tests to avoid CI timeout
* test-backend-ops : disable MOE test with thread sanitizer
* cuda : fix mul_mat_id with multi gpu
* convert : use 1e6 rope_freq_base for mixtral
* convert : fix style
* convert : support safetensors format
* gguf-py : bump version
* metal : add cpy f16 -> f32 kernel
* metal : fix binary ops for ne10 % 4 != 0
* test-backend-ops : add one more sum_rows test
* ggml : do not use BLAS with ggml_mul_mat_id
* convert-hf : support for mixtral-instruct (ggml-org#4428)
* convert : typo fix, add additional hyperparameters, use LLaMA arch for Mixtral-instruct
* convert : use sentencepiece tokenizer for Mixtral-instruct
* convert : make flake8 happy
* metal : fix soft_max kernels
  ref: ggml-org/ggml@1914017
* metal : limit kernels to not use more than the allowed threads

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Radek Pilar <github@mrkva.eu>
1 parent: 9cfbe94 · commit: 99db667

14 files changed: 2,369 additions, 394 deletions
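For readers skimming the commit message above: items such as "llama : update graph to support MoE", "llama : fix expert weighting in the FFN" and "ggml : add n_as argument to ggml_mul_mat_id" implement Mixtral's sparse mixture-of-experts FFN, where each token is routed to a small subset of experts (n_expert_used out of n_expert) and their outputs are combined with normalized gating weights. The NumPy sketch below illustrates only that routing scheme; it is not the ggml implementation, and the function and variable names are invented for the example.

# Illustrative NumPy sketch of Mixtral-style expert routing (invented names, not ggml code).
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def moe_ffn(x, gate_w, experts, n_expert_used=2):
    # x:       [n_embd]            token activation
    # gate_w:  [n_expert, n_embd]  expert-gating (router) weights
    # experts: list of (w1, w2, w3) SwiGLU FFN weights, one tuple per expert
    logits  = gate_w @ x                              # [n_expert] router scores
    probs   = np.exp(logits - logits.max())
    probs  /= probs.sum()                             # softmax over all experts
    top     = np.argsort(probs)[-n_expert_used:]      # keep the n_expert_used best experts
    weights = probs[top] / probs[top].sum()           # renormalize their gating weights
    out = np.zeros_like(x)
    for w, i in zip(weights, top):
        w1, w2, w3 = experts[i]                       # w1, w3: [n_ff, n_embd]; w2: [n_embd, n_ff]
        out += w * (w2 @ (silu(w1 @ x) * (w3 @ x)))   # weighted sum of expert FFN outputs
    return out

Roughly speaking, the ggml_mul_mat_id and 2D ggml_get_rows changes listed above are what let the compute graph perform this per-token expert selection over stacked expert tensors instead of looping per expert.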

Makefile

Lines changed: 5 additions & 0 deletions
@@ -399,6 +399,11 @@ ifdef LLAMA_CUBLAS
 	MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
 	OBJS += ggml-cuda.o
 	NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
+
+ifdef LLAMA_DEBUG
+	NVCCFLAGS += -lineinfo
+endif
+
 ifdef LLAMA_CUDA_NVCC
 	NVCC = $(LLAMA_CUDA_NVCC)
 else
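Usage note (not part of the diff): with this change, a CUDA debug build can be produced with something like make LLAMA_CUBLAS=1 LLAMA_DEBUG=1, which passes nvcc's -lineinfo flag so that profilers and debuggers can map device code back to source lines.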

convert-hf-to-gguf.py

Lines changed: 20 additions & 1 deletion
@@ -77,8 +77,18 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_embedding_length(n_embd)
         if (n_ff := self.hparams.get("intermediate_size")) is not None:
             self.gguf_writer.add_feed_forward_length(n_ff)
-        if (n_head := self.hparams.get("num_attention_head")) is not None:
+        if (n_head := self.hparams.get("num_attention_heads")) is not None:
             self.gguf_writer.add_head_count(n_head)
+        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
+            self.gguf_writer.add_head_count_kv(n_head_kv)
+
+        if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
+            self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
+        if (n_experts := self.hparams.get("num_local_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+
         self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))

     def write_tensors(self):

@@ -170,6 +180,8 @@ def from_model_architecture(model_architecture):
             return StableLMModel
         if model_architecture == "QWenLMHeadModel":
             return QwenModel
+        if model_architecture == "MixtralForCausalLM":
+            return MixtralModel
         return Model

     def _is_model_safetensors(self) -> bool:

@@ -207,6 +219,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.STABLELM
         if arch == "QWenLMHeadModel":
             return gguf.MODEL_ARCH.QWEN
+        if arch == "MixtralForCausalLM":
+            return gguf.MODEL_ARCH.LLAMA

         raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -837,6 +851,11 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_layer_norm_eps(1e-5)


+class MixtralModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+
 class QwenModel(Model):
     @staticmethod
     def token_bytes_to_string(b):
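As a hedged illustration of what the new set_gguf_parameters branches consume, the snippet below sketches the relevant keys of a Mixtral-style Hugging Face config.json and the GGUF metadata each one feeds. Only the key names and the mapping come from the diff above; the numeric values are example placeholders, and the "architectures" entry is assumed to be the field the converter inspects for the architecture name.

# Example placeholder values; key names and the mapping follow the hunks above.
mixtral_like_config = {
    "architectures":       ["MixtralForCausalLM"],  # mapped to MixtralModel / gguf.MODEL_ARCH.LLAMA
    "intermediate_size":   14336,                   # -> add_feed_forward_length
    "num_attention_heads": 32,                      # -> add_head_count
    "num_key_value_heads": 8,                       # -> add_head_count_kv
    "rms_norm_eps":        1e-05,                   # -> add_layer_norm_rms_eps
    "num_local_experts":   8,                       # -> add_expert_count
    "num_experts_per_tok": 2,                       # -> add_expert_used_count
}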

convert.py

Lines changed: 57 additions & 17 deletions
@@ -42,6 +42,7 @@
 ARCH = gguf.MODEL_ARCH.LLAMA

 DEFAULT_CONCURRENCY = 8
+
 #
 # data types
 #

@@ -62,10 +63,10 @@ class UnquantizedDataType(DataType):
     pass


-DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
-DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
-DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
-DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
+DT_F16  = UnquantizedDataType('F16',  dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
+DT_F32  = UnquantizedDataType('F32',  dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
+DT_I32  = UnquantizedDataType('I32',  dtype = np.dtype(np.int16),   valid_conversions = [])
+DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16),  valid_conversions = ['F32', 'F16', 'Q8_0'])


 @dataclass(frozen=True)

@@ -151,14 +152,16 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:

 @dataclass
 class Params:
-    n_vocab:    int
-    n_embd:     int
-    n_layer:    int
-    n_ctx:      int
-    n_ff:       int
-    n_head:     int
-    n_head_kv:  int
-    f_norm_eps: float
+    n_vocab:        int
+    n_embd:         int
+    n_layer:        int
+    n_ctx:          int
+    n_ff:           int
+    n_head:         int
+    n_head_kv:      int
+    n_experts:      int | None = None
+    n_experts_used: int | None = None
+    f_norm_eps:     float | None = None

     rope_scaling_type: gguf.RopeScalingType | None = None
     f_rope_freq_base: float | None = None

@@ -233,6 +236,13 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
            raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
                            "Suggestion: provide 'config.json' of the model in the same directory containing model files.")

+        n_experts = None
+        n_experts_used = None
+
+        if "num_local_experts" in config:
+            n_experts = config["num_local_experts"]
+            n_experts_used = config["num_experts_per_tok"]
+
         return Params(
             n_vocab           = config["vocab_size"],
             n_embd            = config["hidden_size"],

@@ -241,6 +251,8 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
             n_ff              = config["intermediate_size"],
             n_head            = (n_head := config["num_attention_heads"]),
             n_head_kv         = config.get("num_key_value_heads", n_head),
+            n_experts         = n_experts,
+            n_experts_used    = n_experts_used,
             f_norm_eps        = config["rms_norm_eps"],
             f_rope_freq_base  = config.get("rope_theta"),
             rope_scaling_type = rope_scaling_type,

@@ -255,8 +267,15 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
         config = json.load(open(config_path))

+        n_experts = None
+        n_experts_used = None
+        f_rope_freq_base = None
+
         # hack to determine LLaMA v1 vs v2 vs CodeLlama
-        if config.get("rope_theta") == 1000000:
+        if config.get("moe"):
+            # Mixtral
+            n_ctx = 32768
+        elif config.get("rope_theta") == 1000000:
             # CodeLlama
             n_ctx = 16384
         elif config["norm_eps"] == 1e-05:

@@ -266,16 +285,27 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
             # LLaMA v1
             n_ctx = 2048

+        if "layers.0.feed_forward.w1.weight" in model:
+            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
+
+        if config.get("moe"):
+            n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
+            n_experts = config["moe"]["num_experts"]
+            n_experts_used = config["moe"]["num_experts_per_tok"]
+            f_rope_freq_base = 1e6
+
         return Params(
             n_vocab          = model["tok_embeddings.weight"].shape[0],
             n_embd           = config["dim"],
             n_layer          = config["n_layers"],
             n_ctx            = n_ctx,
-            n_ff             = model["layers.0.feed_forward.w1.weight"].shape[0],
+            n_ff             = n_ff,
             n_head           = (n_head := config["n_heads"]),
             n_head_kv        = config.get("n_kv_heads", n_head),
+            n_experts        = n_experts,
+            n_experts_used   = n_experts_used,
             f_norm_eps       = config["norm_eps"],
-            f_rope_freq_base = config.get("rope_theta"),
+            f_rope_freq_base = config.get("rope_theta", f_rope_freq_base),
         )

     @staticmethod

@@ -832,7 +862,17 @@ def add_meta_arch(self, params: Params) -> None:
         self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
         self.gguf.add_head_count          (params.n_head)
         self.gguf.add_head_count_kv       (params.n_head_kv)
-        self.gguf.add_layer_norm_rms_eps  (params.f_norm_eps)
+
+        if params.n_experts:
+            self.gguf.add_expert_count(params.n_experts)
+
+        if params.n_experts_used:
+            self.gguf.add_expert_used_count(params.n_experts_used)
+
+        if params.f_norm_eps:
+            self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
+        else:
+            raise ValueError('f_norm_eps is None')

         if params.f_rope_freq_base is not None:
             self.gguf.add_rope_freq_base(params.f_rope_freq_base)

@@ -956,7 +996,7 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM


 def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
-    wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) +".weight"].data_type
+    wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type

     if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
         return GGMLFileType.AllF32
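To make the loadOriginalParamsJson changes above concrete, the sketch below shows the shape of an original-format params.json that would take the new "moe" branch. The values are placeholders; only the key names come from the config[...] accesses in the diff.

# Placeholder values; key names follow loadOriginalParamsJson above.
mixtral_like_params_json = {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "norm_eps": 1e-05,
    "moe": {"num_experts": 8, "num_experts_per_tok": 2},
}
# With a "moe" block present, the loader sets n_ctx = 32768, reads n_ff from the
# layers.0.feed_forward.experts.0.w1.weight tensor, and falls back to a 1e6
# rope_freq_base when "rope_theta" is absent from the config.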
