Commit d7fd29f

icecream95, compilade, and ggerganov authored

llama : add OpenELM support (#7359)
* Initial OpenELM support (270M only so far)

* Fill out missing entries in llama_model_type_name

* fixup! Initial OpenELM support (270M only so far)

  Fix formatting

* llama : support all OpenELM models

* llama : add variable GQA and variable FFN sizes

  Some metadata keys can now also be arrays to support setting their value per-layer for models like OpenELM.

* llama : minor spacing changes

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* llama : use std::array for per-layer hparams

* llama : fix save/load state

* llama : do not print hparams for vocab-only models

* llama : handle n_head == 0

* llama : use const ref for print_f and fix division by zero

* llama : fix t5 uses of n_head and n_ff

* llama : minor comment

---------

Co-authored-by: Francis Couture-Harpin <git@compilade.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 6f63d64 commit d7fd29f

File tree

5 files changed: +675 -175 lines changed


convert_hf_to_gguf.py

Lines changed: 123 additions & 34 deletions
@@ -13,7 +13,7 @@
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast

 import math
 import numpy as np
@@ -677,6 +677,51 @@ def _set_vocab_llama_hf(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)

+    def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
+        tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
+        logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
+        vocab_reader = gguf.GGUFReader(tokenizer_path, "r")
+
+        default_pre = "mpt" if model_name == "gpt-neox" else "default"
+
+        field = vocab_reader.get_field(gguf.Keys.Tokenizer.MODEL)
+        assert field  # tokenizer model
+        self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]).decode("utf-8"))
+
+        field = vocab_reader.get_field(gguf.Keys.Tokenizer.PRE)
+        self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1]).decode("utf-8") if field else default_pre)
+
+        field = vocab_reader.get_field(gguf.Keys.Tokenizer.LIST)
+        assert field  # token list
+        self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
+
+        if model_name == "llama-spm":
+            field = vocab_reader.get_field(gguf.Keys.Tokenizer.SCORES)
+            assert field  # token scores
+            self.gguf_writer.add_token_scores([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
+
+        field = vocab_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
+        assert field  # token types
+        self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
+
+        if model_name != "llama-spm":
+            field = vocab_reader.get_field(gguf.Keys.Tokenizer.MERGES)
+            assert field  # token merges
+            self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
+
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)) is not None:
+            self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)) is not None:
+            self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)) is not None:
+            self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.PAD_ID)) is not None:
+            self.gguf_writer.add_pad_token_id(field.parts[-1].tolist()[0])
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_BOS)) is not None:
+            self.gguf_writer.add_add_bos_token(field.parts[-1].tolist()[0])
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_EOS)) is not None:
+            self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
+

 @Model.register("GPTNeoXForCausalLM")
 class GPTNeoXModel(Model):
@@ -2439,39 +2484,7 @@ def set_vocab(self):
             self._set_vocab_sentencepiece()
         else:
             # Use the GPT-NeoX tokenizer when no tokenizer files are present
-            tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
-            logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
-            neox_reader = gguf.GGUFReader(tokenizer_path, "r")
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
-            self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]).decode("utf-8") if field else "gpt2")
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.PRE)
-            self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1]).decode("utf-8") if field else "mpt")
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
-            assert field
-            self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
-            assert field
-            self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
-            assert field
-            self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
-            self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0] if field else 1)
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
-            self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0] if field else 0)
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
-            self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0] if field else 0)
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.PAD_ID)
-            self.gguf_writer.add_pad_token_id(field.parts[-1].tolist()[0] if field else 0)
+            self._set_vocab_builtin("gpt-neox", vocab_size)

     def set_gguf_parameters(self):
         d_model = self.find_hparam(["hidden_size", "d_model"])
@@ -2623,6 +2636,82 @@ def set_vocab(self, *args, **kwargs):
         self.gguf_writer.add_add_eos_token(True)


+@Model.register("OpenELMForCausalLM")
+class OpenELMModel(Model):
+    model_arch = gguf.MODEL_ARCH.OPENELM
+
+    @staticmethod
+    def _make_divisible(v: float | int, divisor: int) -> int:
+        # ref: https://huggingface.co/apple/OpenELM-270M-Instruct/blob/eb111ff2e6724348e5b905984063d4064d4bc579/configuration_openelm.py#L34-L38
+        new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
+        # Make sure that round down does not go down by more than 10%.
+        if new_v < 0.9 * v:
+            new_v += divisor
+        return new_v
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        ffn_multipliers: list[float] = self.hparams["ffn_multipliers"]
+        ffn_dim_divisor: int = self.hparams["ffn_dim_divisor"]
+        self._n_embd: int = self.hparams["model_dim"]
+        self._num_kv_heads: list[int] = self.hparams["num_kv_heads"]
+        self._num_query_heads: list[int] = self.hparams["num_query_heads"]
+        self._ffn_dims: list[int] = [
+            OpenELMModel._make_divisible(multiplier * self._n_embd, ffn_dim_divisor)
+            for multiplier in ffn_multipliers
+        ]
+        assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int)
+        assert isinstance(self._num_query_heads, list) and isinstance(self._num_query_heads[0], int)
+
+    # Uses the tokenizer from meta-llama/Llama-2-7b-hf
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_builtin("llama-spm", self.hparams["vocab_size"])
+
+    def set_gguf_parameters(self):
+        n_embd = self._n_embd
+        head_dim = self.hparams["head_dim"]
+        rot_pct = 1.0
+        assert self.block_count == len(self._num_kv_heads)
+        assert self.block_count == len(self._num_query_heads)
+        assert self.block_count == len(self._ffn_dims)
+
+        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams["max_context_length"])
+        self.gguf_writer.add_embedding_length(n_embd)
+        self.gguf_writer.add_feed_forward_length(self._ffn_dims)
+        self.gguf_writer.add_head_count(self._num_query_heads)
+        self.gguf_writer.add_head_count_kv(self._num_kv_heads)
+        self.gguf_writer.add_rope_freq_base(self.hparams["rope_freq_constant"])
+        # https://huggingface.co/apple/OpenELM-270M-Instruct/blob/c401df2/modeling_openelm.py#L30
+        self.gguf_writer.add_layer_norm_rms_eps(1e-6)
+        self.gguf_writer.add_rope_dimension_count(int(rot_pct * head_dim))
+        self.gguf_writer.add_key_length(head_dim)
+        self.gguf_writer.add_value_length(head_dim)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        if "n_layers" in keys:
+            return self.hparams["num_transformer_layers"]
+
+        return super().find_hparam(keys, optional)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        # split ff
+        if bid is not None and name == f"transformer.layers.{bid}.ffn.proj_1.weight":
+            ff_dim = self._ffn_dims[bid]
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim])
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:])
+            return
+
+        yield (self.map_tensor_name(name), data_torch)
+
+
 @Model.register("ArcticForCausalLM")
 class ArcticModel(Model):
     model_arch = gguf.MODEL_ARCH.ARCTIC
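
Note (not part of the commit): OpenELM sizes each layer's feed-forward block with a per-layer multiplier, rounded by _make_divisible to a multiple of ffn_dim_divisor. A minimal, self-contained sketch of that computation; the hyperparameter values below are illustrative placeholders, not taken from a real OpenELM config.json:

# Standalone sketch of the per-layer FFN-size computation above.
def make_divisible(v: float, divisor: int) -> int:
    # Round v to the nearest multiple of divisor, never below divisor,
    # and never more than 10% below v (same rule as OpenELMModel._make_divisible).
    new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

model_dim = 1280                        # hypothetical "model_dim"
ffn_dim_divisor = 256                   # hypothetical "ffn_dim_divisor"
ffn_multipliers = [0.5, 1.0, 2.5, 4.0]  # hypothetical per-layer "ffn_multipliers"

ffn_dims = [make_divisible(m * model_dim, ffn_dim_divisor) for m in ffn_multipliers]
print(ffn_dims)  # [768, 1280, 3328, 5120]

# For layer bid, transformer.layers.{bid}.ffn.proj_1.weight has
# 2 * ffn_dims[bid] output rows: the first ffn_dims[bid] become FFN_GATE,
# the rest become FFN_UP (see OpenELMModel.modify_tensors above).

The resulting per-layer widths are both written to the GGUF metadata (add_feed_forward_length) and used to decide where proj_1 is split into gate and up tensors.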

gguf-py/gguf/constants.py

Lines changed: 15 additions & 0 deletions
@@ -160,6 +160,7 @@ class MODEL_ARCH(IntEnum):
     COMMAND_R = auto()
     DBRX = auto()
     OLMO = auto()
+    OPENELM = auto()
     ARCTIC = auto()
     DEEPSEEK2 = auto()
     BITNET = auto()
@@ -285,6 +286,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",
     MODEL_ARCH.OLMO: "olmo",
+    MODEL_ARCH.OPENELM: "openelm",
     MODEL_ARCH.ARCTIC: "arctic",
     MODEL_ARCH.DEEPSEEK2: "deepseek2",
     MODEL_ARCH.BITNET: "bitnet",
@@ -861,6 +863,19 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.OPENELM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.ARCTIC: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
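
Note (illustrative, not part of the commit): the new OPENELM entries are plain additions to gguf-py's lookup tables, so the architecture name and its tensor set become queryable like any other. A short sketch, assuming gguf.constants still exports MODEL_ARCH_NAMES, MODEL_TENSORS, and TENSOR_NAMES as it did at the time of this commit:

# Sketch: listing the GGUF tensor names registered for OpenELM.
import gguf

arch = gguf.MODEL_ARCH.OPENELM
print(gguf.MODEL_ARCH_NAMES[arch])  # "openelm"

for tensor in gguf.MODEL_TENSORS[arch]:
    # Per-layer tensor names keep a "{bid}" placeholder for the block index;
    # global tensors (token_embd, output_norm) have no placeholder.
    print(gguf.TENSOR_NAMES[tensor].format(bid=0))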

gguf-py/gguf/gguf_writer.py

Lines changed: 15 additions & 6 deletions
@@ -480,8 +480,11 @@ def add_block_count(self, length: int) -> None:
     def add_leading_dense_block_count(self, length: int) -> None:
         self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)

-    def add_feed_forward_length(self, length: int) -> None:
-        self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+    def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
+        if isinstance(length, int):
+            self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+        else:
+            self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)

     def add_expert_feed_forward_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
@@ -495,11 +498,17 @@ def add_parallel_residual(self, use: bool) -> None:
     def add_decoder_start_token_id(self, id: int) -> None:
         self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)

-    def add_head_count(self, count: int) -> None:
-        self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+    def add_head_count(self, count: int | Sequence[int]) -> None:
+        if isinstance(count, int):
+            self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+        else:
+            self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)

-    def add_head_count_kv(self, count: int) -> None:
-        self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+    def add_head_count_kv(self, count: int | Sequence[int]) -> None:
+        if isinstance(count, int):
+            self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+        else:
+            self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)

     def add_key_length(self, length: int) -> None:
         self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
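
Note (illustrative, not part of the commit): with this change the same writer call accepts either a scalar or a per-layer sequence, so converters like OpenELMModel can pass lists directly. A hedged sketch; the helper function and the numbers are made up, only the add_* methods come from the diff above:

# Sketch of handing per-layer values to GGUFWriter after this change.
from __future__ import annotations

from typing import Sequence
import gguf

def write_attention_shape(writer: gguf.GGUFWriter,
                          n_head: int | Sequence[int],
                          n_head_kv: int | Sequence[int],
                          n_ff: int | Sequence[int]) -> None:
    # Scalars are stored as uint32 keys, sequences as GGUF arrays (per-layer values).
    writer.add_head_count(n_head)
    writer.add_head_count_kv(n_head_kv)
    writer.add_feed_forward_length(n_ff)

# Uniform model: one value shared by every layer.
# write_attention_shape(writer, 32, 8, 11008)

# OpenELM-style model: one value per layer (illustrative numbers).
# write_attention_shape(writer, [12, 12, 16], [3, 3, 4], [768, 1024, 1280])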

gguf-py/gguf/tensor_mapping.py

Lines changed: 13 additions & 3 deletions
@@ -24,6 +24,7 @@ class TensorNameMap:
             "backbone.embedding",        # mamba
             "backbone.embeddings",       # mamba-hf
             "transformer.in_out_embed",  # Grok
+            "transformer.token_embeddings",  # openelm
             "shared",                    # t5
         ),

@@ -37,6 +38,7 @@ class TensorNameMap:
             "word_embeddings_layernorm",  # bloom
             "embeddings.LayerNorm",       # bert
             "emb_ln",                     # nomic-bert
+            "transformer.norm",  # openelm
         ),

         # Position embeddings
@@ -69,6 +71,7 @@ class TensorNameMap:
             "model.norm_f",          # mamba-qbert
             "backbone.norm_f",       # mamba
             "transformer.rms_norm",  # Grok
+            "transformer.norm",  # openelm
         ),

         # Rope frequencies
@@ -98,6 +101,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.norm",  # mamba
             "transformer.decoder_layer.{bid}.rms_norm",  # Grok
             "transformer.blocks.{bid}.norm_attn_norm.norm_1",  # dbrx
+            "transformer.layers.{bid}.attn_norm",  # openelm
         ),

         # Attention norm 2
@@ -119,7 +123,8 @@ class TensorNameMap:
             "h.{bid}.attn.c_attn",  # gpt2
             "transformer.h.{bid}.mixer.Wqkv",  # phi2
             "encoder.layers.{bid}.attn.Wqkv",  # nomic-bert
-            "model.layers.{bid}.self_attn.qkv_proj"  # phi3
+            "model.layers.{bid}.self_attn.qkv_proj",  # phi3
+            "transformer.layers.{bid}.attn.qkv_proj",  # openelm
         ),

         # Attention query
@@ -177,6 +182,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.attn.out_proj",  # nomic-bert
             "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",  # dbrx
+            "transformer.layers.{bid}.attn.out_proj",  # openelm
         ),

         # Attention output norm
@@ -212,6 +218,7 @@ class TensorNameMap:
             "h.{bid}.ln_2",  # gpt2
             "model.layers.{bid}.ffn_norm",  # internlm2
             "transformer.decoder_layer.{bid}.rms_norm_2",  # Grok
+            "transformer.layers.{bid}.ffn_norm",  # openelm
         ),

         # Post feed-forward norm
@@ -327,6 +334,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.fc2",  # nomic-bert
             "model.layers.{bid}.mlp.c_proj",  # starcoder2
             "encoder.layer.{bid}.mlp.wo",  # jina-bert-v2
+            "transformer.layers.{bid}.ffn.proj_2",  # openelm
             "model.layers.{bid}.residual_mlp.w2",  # arctic
             "encoder.layer.{bid}.mlp.down_layer",  # jina-bert-v2
         ),
@@ -348,15 +356,17 @@
             "model.layers.{bid}.self_attn.q_layernorm",  # persimmon
             "model.layers.{bid}.self_attn.q_norm",  # cohere
             "transformer.blocks.{bid}.attn.q_ln",  # sea-lion
-            "encoder.layer.{bid}.attention.self.layer_norm_q"  # jina-bert-v2
+            "encoder.layer.{bid}.attention.self.layer_norm_q",  # jina-bert-v2
+            "transformer.layers.{bid}.attn.q_norm",  # openelm
         ),

         MODEL_TENSOR.ATTN_K_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
             "model.layers.{bid}.self_attn.k_layernorm",  # persimmon
             "model.layers.{bid}.self_attn.k_norm",  # cohere
             "transformer.blocks.{bid}.attn.k_ln",  # sea-lion
-            "encoder.layer.{bid}.attention.self.layer_norm_k"  # jina-bert-v2
+            "encoder.layer.{bid}.attention.self.layer_norm_k",  # jina-bert-v2
+            "transformer.layers.{bid}.attn.k_norm",  # openelm
         ),

         MODEL_TENSOR.ROPE_FREQS: (
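
Note (illustrative, not part of the commit): these mapping entries are what let the converter resolve OpenELM's Hugging Face tensor names to their GGUF names. A small sketch, assuming gguf-py's get_tensor_name_map() helper and TensorNameMap.get_name(key, try_suffixes=...) as they existed around this commit:

# Sketch: resolving an OpenELM checkpoint tensor name to its GGUF name.
import gguf

# Build the name map for a hypothetical 3-block OpenELM model.
name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.OPENELM, 3)

hf_name = "transformer.layers.0.attn.qkv_proj.weight"
gguf_name = name_map.get_name(hf_name, try_suffixes=(".weight", ".bias"))
print(gguf_name)  # expected: the fused QKV tensor for block 0, e.g. "blk.0.attn_qkv.weight"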
