
Commit 82b051a

ymcki authored and tinglou committed
llama : support for Llama-3_1-Nemotron-51B (ggml-org#10669)
* conflict resolution
* move comments after bracket to its own line
1 parent 8e8b068 commit 82b051a
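
With this support in place, a local Hugging Face checkout of the model should convert through convert_hf_to_gguf.py like any other supported architecture. A rough, hypothetical invocation is sketched below; the checkpoint path and output file name are placeholders, and --outfile/--outtype are the converter's usual options, not something introduced by this commit.

# Hypothetical conversion run from the llama.cpp repository root; paths are placeholders.
import subprocess

subprocess.run([
    "python", "convert_hf_to_gguf.py",
    "models/Llama-3_1-Nemotron-51B-Instruct",  # local HF snapshot (placeholder)
    "--outfile", "nemotron-51b-f16.gguf",      # output GGUF file (placeholder)
    "--outtype", "f16",
], check=True)

Note the comment in the new set_vocab: the model's tokenizer_config.json eos_token is expected to be changed from '|eot_id|' to '|end_of_text|' before converting.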

File tree

4 files changed: +471 -1 lines changed

convert_hf_to_gguf.py

Lines changed: 178 additions & 0 deletions
@@ -1692,6 +1692,184 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("DeciLMForCausalLM")
+class DeciModel(Model):
+    model_arch = gguf.MODEL_ARCH.DECI
+
+    @staticmethod
+    def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
+        # DeciLM-specific code
+        intermediate_size = int(2 * ffn_mult * n_embd / 3)
+        return DeciModel._find_multiple(intermediate_size, 256)
+
+    @staticmethod
+    def _find_multiple(n: int, k: int) -> int:
+        # DeciLM-specific code
+        if n % k == 0:
+            return n
+        return n + k - (n % k)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B
+            _block_configs: list[dict[str,Any]] = self.hparams["block_configs"]
+            assert self.block_count == len(_block_configs)
+            self._num_kv_heads = list()
+            self._num_heads = list()
+            _ffn_multipliers = list()
+            # ***linear attention layer***
+            # if n_heads_in_group is None and replace_with_linear is True
+            # then _num_kv_heads[il] is 0 and _num_heads[il] is num_attention_heads
+            # ***attention-free layer***
+            # if n_heads_in_group is None and replace_with_linear is False
+            # then _num_kv_heads[il] is 0 and _num_heads[il] is 0
+            # ***normal attention-layer***
+            # if n_heads_in_group is not None, then
+            # _num_kv_heads[il] is num_attention_head // n_heads_in_group and
+            # _num_heads[il] is num_attention_head
+            for il in range(len(_block_configs)):
+                if _block_configs[il]["attention"]["n_heads_in_group"] is None:
+                    if _block_configs[il]["attention"]["replace_with_linear"] is True:
+                        self._num_kv_heads.append(0)
+                        self._num_heads.append(self.hparams["num_attention_heads"])
+                    else:
+                        self._num_kv_heads.append(0)
+                        self._num_heads.append(0)
+                else:
+                    self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"])
+                    self._num_heads.append(self.hparams["num_attention_heads"])
+                _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"])
+            assert self.block_count == len(self._num_kv_heads)
+            assert self.block_count == len(self._num_heads)
+            assert self.block_count == len(_ffn_multipliers)
+            assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int)
+            assert isinstance(self._num_heads, list) and isinstance(self._num_heads[0], int)
+            assert isinstance(_ffn_multipliers, list) and isinstance(_ffn_multipliers[0], float)
+            self._ffn_dims: list[int] = [
+                DeciModel._ffn_mult_to_intermediate_size(multiplier, self.hparams["hidden_size"])
+                for multiplier in _ffn_multipliers
+            ]
+
+    def set_vocab(self):
+        # Please change tokenizer_config.json of Llama-3_1-Nemotron-51B's
+        # eos_token from '|eot_id|' to '|end_of_text|'
+        if self.hparams.get("vocab_size", 128256) == 128256:
+            tokens, toktypes, tokpre = self.get_vocab_base()
+            self.gguf_writer.add_tokenizer_model("gpt2")
+            self.gguf_writer.add_tokenizer_pre(tokpre)
+            self.gguf_writer.add_token_list(tokens)
+            self.gguf_writer.add_token_types(toktypes)
+
+            special_vocab = gguf.SpecialVocab(
+                self.dir_model, load_merges=True,
+                special_token_types = ['bos', 'eos', 'eom', 'eot']
+            )
+            special_vocab._set_special_token("bos", 128000)
+            special_vocab._set_special_token("eos", 128001)
+            special_vocab._set_special_token("eom", 128008)
+            special_vocab._set_special_token("eot", 128009)
+            special_vocab.add_to_gguf(self.gguf_writer)
+        else:
+            # DeciLM-7B
+            self._set_vocab_llama_hf()
+            # self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B
+            assert self.block_count == len(self._num_kv_heads)
+            assert self.block_count == len(self._num_heads)
+            assert self.block_count == len(self._ffn_dims)
+            self.gguf_writer.add_head_count_kv(self._num_kv_heads)
+            self.gguf_writer.add_head_count(self._num_heads)
+            self.gguf_writer.add_feed_forward_length(self._ffn_dims)
+            self.gguf_writer.add_block_count(self.block_count)
+            self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+            self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+            self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+            self.gguf_writer.add_key_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+            self.gguf_writer.add_value_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+            self.gguf_writer.add_file_type(self.ftype)
+        else: # DeciLM-7B
+            super().set_gguf_parameters()
+            if "num_key_value_heads_per_layer" in self.hparams: # DeciLM-7B
+                self._num_kv_heads: list[int] = self.hparams["num_key_value_heads_per_layer"]
+                assert self.block_count == len(self._num_kv_heads)
+                self.gguf_writer.add_head_count_kv(self._num_kv_heads)
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        if "head_dim" in hparams:
+            rope_dim = hparams["head_dim"]
+        else:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "linear":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+
+    @staticmethod
+    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
+        if n_head_kv is not None and n_head != n_head_kv:
+            n_head = n_head_kv
+        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        n_head = self.hparams["num_attention_heads"]
+        if bid is not None:
+            if "num_key_value_heads_per_layer" in self.hparams:
+                n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid]
+            elif "block_configs" in self.hparams:
+                n_kv_head = self._num_kv_heads[bid]
+                n_head = self._num_heads[bid]
+            else:
+                n_kv_head = self.hparams.get("num_key_value_heads")
+        else:
+            n_kv_head = self.hparams.get("num_key_value_heads")
+
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
+            data_torch = DeciModel.permute(data_torch, n_head, n_head)
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
+            data_torch = DeciModel.permute(data_torch, n_head, n_kv_head)
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+            if rope_scaling.get("rope_type", '').lower() == "llama3":
+                base = self.hparams.get("rope_theta", 10000.0)
+                dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+                factor = rope_scaling.get("factor", 8.0)
+                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+                old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
+
+                low_freq_wavelen = old_context_len / low_freq_factor
+                high_freq_wavelen = old_context_len / high_freq_factor
+                assert low_freq_wavelen != high_freq_wavelen
+
+                rope_factors = []
+                for freq in freqs:
+                    wavelen = 2 * math.pi / freq
+                    if wavelen < high_freq_wavelen:
+                        rope_factors.append(1)
+                    elif wavelen > low_freq_wavelen:
+                        rope_factors.append(factor)
+                    else:
+                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+                yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+
 @Model.register("BitnetForCausalLM")
 class BitnetModel(Model):
     model_arch = gguf.MODEL_ARCH.BITNET
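
The block_configs handling above is the core of the converter change, so here is a minimal standalone sketch of the same per-layer bookkeeping. The hidden size, head count, and the three block_configs entries are made-up illustrative values, not the real Llama-3_1-Nemotron-51B configuration; only the rounding rule and the head-count rules mirror the code in the diff.

# Standalone sketch of the per-layer block_configs bookkeeping (illustrative values).
def find_multiple(n: int, k: int) -> int:
    # round n up to the next multiple of k, as in DeciModel._find_multiple
    return n if n % k == 0 else n + k - (n % k)

def ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
    # as in DeciModel._ffn_mult_to_intermediate_size
    return find_multiple(int(2 * ffn_mult * n_embd / 3), 256)

hidden_size = 8192           # hypothetical
num_attention_heads = 64     # hypothetical

# one block of each kind: grouped-query attention, linear attention, attention-free
block_configs = [
    {"attention": {"n_heads_in_group": 8,    "replace_with_linear": False}, "ffn": {"ffn_mult": 2.625}},
    {"attention": {"n_heads_in_group": None, "replace_with_linear": True},  "ffn": {"ffn_mult": 2.625}},
    {"attention": {"n_heads_in_group": None, "replace_with_linear": False}, "ffn": {"ffn_mult": 1.3}},
]

num_kv_heads, num_heads, ffn_dims = [], [], []
for cfg in block_configs:
    attn = cfg["attention"]
    if attn["n_heads_in_group"] is None:
        # no KV cache for this layer; linear attention keeps the full head count,
        # an attention-free block drops to zero heads
        num_kv_heads.append(0)
        num_heads.append(num_attention_heads if attn["replace_with_linear"] else 0)
    else:
        num_kv_heads.append(num_attention_heads // attn["n_heads_in_group"])
        num_heads.append(num_attention_heads)
    ffn_dims.append(ffn_mult_to_intermediate_size(cfg["ffn"]["ffn_mult"], hidden_size))

print(num_kv_heads)  # [8, 0, 0]
print(num_heads)     # [64, 64, 0]
print(ffn_dims)      # [14336, 14336, 7168]

These three lists are exactly what set_gguf_parameters writes via add_head_count_kv, add_head_count, and add_feed_forward_length, which is why the GGUF metadata for this architecture carries per-layer values rather than single integers.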

gguf-py/gguf/constants.py

Lines changed: 26 additions & 0 deletions
@@ -221,6 +221,7 @@ class GGUFType:
 
 class MODEL_ARCH(IntEnum):
     LLAMA = auto()
+    DECI = auto()
     FALCON = auto()
     BAICHUAN = auto()
     GROK = auto()
@@ -402,6 +403,7 @@ class MODEL_TENSOR(IntEnum):
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.LLAMA: "llama",
+    MODEL_ARCH.DECI: "deci",
     MODEL_ARCH.FALCON: "falcon",
     MODEL_ARCH.BAICHUAN: "baichuan",
     MODEL_ARCH.GROK: "grok",
@@ -602,6 +604,26 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.DECI: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.GROK: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -1448,6 +1470,10 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.DECI: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
     MODEL_ARCH.BAICHUAN: [
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,

gguf-py/gguf/tensor_mapping.py

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ class TensorNameMap:
             "transformer.h.{bid}.self_attention.dense", # falcon
             "h.{bid}.self_attention.dense", # bloom
             "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
+            "model.layers.{bid}.self_attn.linear_attn", # deci
             "layers.{bid}.attention.wo", # llama-pth
             "encoder.layer.{bid}.attention.output.dense", # bert
             "transformer.h.{bid}.attn.out_proj", # gpt-j
