diff --git a/model.py b/model.py
index c2657995..17178c4f 100644
--- a/model.py
+++ b/model.py
@@ -79,6 +79,12 @@ def from_name(cls, name: str):
     "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
         rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
     ),
+    "llama-3.2-1b": dict(block_size=131072, n_layer=16, n_head=32, n_local_heads=8, dim=2048, intermediate_size=8192, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
+    "llama-3.2-3b": dict(block_size=131072, n_layer=28, n_head=24, n_local_heads=8, dim=3072, intermediate_size=8192, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
 }
 
 class KVCache(nn.Module):
diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index f14ba6ca..a45b955c 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -7,6 +7,7 @@
 import re
 import shutil
 import sys
+from typing import Dict
 from pathlib import Path
 from typing import Optional
 from safetensors.torch import load_file as load_safetensors_file
@@ -22,7 +23,9 @@
 @torch.inference_mode()
 def convert_hf_checkpoint(
     *,
-    checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"),
+    checkpoint_dir: Path = Path(
+        "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"
+    ),
     model_name: Optional[str] = None,
 ) -> None:
     if model_name is None:
@@ -31,38 +34,81 @@ def convert_hf_checkpoint(
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
-    model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
+    # Check for a single, non-sharded safetensors file
+    model_solo_safetensors = checkpoint_dir / "model.safetensors"
+    if model_solo_safetensors.is_file():
+        print(f"Found whole safetensors file at {model_solo_safetensors}")
+        state_dict = load_safetensors_file(str(model_solo_safetensors), device="cpu")
+    else:
+        # If no single file exists, merge the sharded files listed in the index
+        state_dict = merge_model_indices(checkpoint_dir)
+
+    final_result = process_state_dict(state_dict, config)
+
+    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
+    torch.save(final_result, checkpoint_dir / "model.pth")
+
+    if "llama-3-" in model_name.lower() or "llama-3.1-" in model_name.lower() or "llama-3.2-" in model_name.lower():
+        if "llama-3.1-405b" in model_name.lower():
+            original_dir = checkpoint_dir / "original" / "mp16"
+        else:
+            original_dir = checkpoint_dir / "original"
+        tokenizer_model = original_dir / "tokenizer.model"
+        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
+        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
+        shutil.copy(tokenizer_model, tokenizer_model_tiktoken)
+
+
+def merge_model_indices(checkpoint_dir: Path) -> Dict[str, torch.Tensor]:
+    model_map_json_safetensors = checkpoint_dir / "model.safetensors.index.json"
     model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
     model_map_json = None
-
+
     try:
-        assert model_map_json_safetensors.is_file()
-        model_map_json = model_map_json_safetensors
-        print(f"Found safetensors index at {model_map_json_safetensors}")
+        assert model_map_json_safetensors.is_file()
+        model_map_json = model_map_json_safetensors
+        print(f"Found safetensors index at {model_map_json_safetensors}")
     except AssertionError:
-        print(f"{model_map_json_safetensors} not found")
+        print(f"{model_map_json_safetensors} not found")
+
+    if model_map_json is None:
+        try:
+            assert model_map_json_pytorch.is_file()
+            model_map_json = model_map_json_pytorch
+            print(f"Found pytorch index at {model_map_json_pytorch}")
+        except AssertionError:
+            print(f"{model_map_json_pytorch} not found")
+
     if model_map_json is None:
-        try:
-            assert model_map_json_pytorch.is_file()
-            model_map_json = model_map_json_pytorch
-            print(f"Found pytorch index at {model_map_json_pytorch}")
-        except AssertionError:
-            print(f"{model_map_json_pytorch} not found")
-
-    if model_map_json is None:
-       raise Exception("No model map found!")
+        raise Exception("No model map found!")
 
     with open(model_map_json) as json_map:
         bin_index = json.load(json_map)
+    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
+
+    merged_result = {}
+    for file in sorted(bin_files):
+        if "safetensors" in str(file):
+            state_dict = load_safetensors_file(str(file), device="cpu")
+        else:
+            state_dict = torch.load(
+                str(file), map_location="cpu", mmap=True, weights_only=True
+            )
+        merged_result.update(state_dict)
+    return merged_result
+
+
+def process_state_dict(
+    state_dict: Dict[str, torch.Tensor], config: ModelArgs
+) -> Dict[str, torch.Tensor]:
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
         "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
         "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
         "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
         "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
+        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
         "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
         "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
         "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
@@ -70,37 +116,35 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
-
-    def permute(w, n_head):
-        dim = config.dim
-        return (
-            w.view(n_head, 2, config.head_dim // 2, dim)
-            .transpose(1, 2)
-            .reshape(config.head_dim * n_head, dim)
-        )
 
-    merged_result = {}
-    for file in sorted(bin_files):
-        if "safetensors" in str(file):
-            state_dict = load_safetensors_file(str(file), device="cpu")
-            merged_result.update(state_dict)
-        else:
-            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-            merged_result.update(state_dict)
     final_result = {}
-    for key, value in merged_result.items():
+    for key, value in state_dict.items():
         if "layers" in key:
-            abstract_key = re.sub(r'(\d+)', '{}', key)
-            layer_num = re.search(r'\d+', key).group(0)
-            new_key = weight_map[abstract_key]
+            abstract_key = re.sub(r"(\d+)", "{}", key)
+            layer_num = re.search(r"\d+", key).group(0)
+            new_key = weight_map.get(abstract_key)
             if new_key is None:
                 continue
             new_key = new_key.format(layer_num)
         else:
-            new_key = weight_map[key]
+            new_key = weight_map.get(key)
 
-        final_result[new_key] = value
+        if new_key:
+            final_result[new_key] = value
 
+    # tie embeddings if the output weight does not exist
+    # necessary for 1B and 3B models
+    if "output.weight" not in final_result:
+        print("Tying embeddings - this is only necessary for 1B and 3B models")
+        final_result["output.weight"] = final_result["tok_embeddings.weight"]
+
+    def permute(w, n_head):
+        dim = config.dim
+        return (
+            w.view(n_head, 2, config.head_dim // 2, dim)
+            .transpose(1, 2)
+            .reshape(config.head_dim * n_head, dim)
+        )
+
     for key in tuple(final_result.keys()):
         if "wq" in key:
@@ -113,23 +157,20 @@ def permute(w, n_head):
             del final_result[key]
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
-    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
-    torch.save(final_result, checkpoint_dir / "model.pth")
-    if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower():
-        if 'llama-3.1-405b' in model_name.lower():
-            original_dir = checkpoint_dir / "original" / "mp16"
-        else:
-            original_dir = checkpoint_dir / "original"
-        tokenizer_model = original_dir / "tokenizer.model"
-        tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
-        print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
-        shutil.copy(tokenizer_model, tokenizer_model_tiktoken)
-if __name__ == '__main__':
+    return final_result
+
+
+if __name__ == "__main__":
     import argparse
-    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
-    parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"))
-    parser.add_argument('--model_name', type=str, default=None)
+
+    parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.")
+    parser.add_argument(
+        "--checkpoint_dir",
+        type=Path,
+        default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"),
+    )
+    parser.add_argument("--model_name", type=str, default=None)
     args = parser.parse_args()
     convert_hf_checkpoint(