
Commit 73a9055

fixing some issues with our support for 70/405B models
Summary: download and convert scripts needed to be updated alongside model.py config files

Test Plan: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-70B/model.pth

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 8aa6533 commit 73a9055

3 files changed: +84 -87 lines changed

scripts/convert_hf_checkpoint.py

Lines changed: 76 additions & 85 deletions
```diff
@@ -8,9 +8,10 @@
 import json
 import re
 import shutil
+import sys
 from pathlib import Path
 from typing import Optional
-
+from safetensors.torch import load_file as load_safetensors_file
 import torch
 
 from torchao._models.llama.model import ModelArgs
```
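The newly imported `load_safetensors_file` is used further down to read `.safetensors` shards directly. For context, the index file the updated script looks for below maps every parameter name to the shard that stores it; a small illustrative sketch of that layout (the checkpoint path and shard names are examples, not pinned by this commit):

```python
# Illustrative: how bin_files is derived from a sharded-checkpoint index.
# The {"weight_map": {param_name: shard_file}} layout is the standard
# Hugging Face format; the concrete path and shard names are made up.
import json
from pathlib import Path

checkpoint_dir = Path("checkpoints/meta-llama/Meta-Llama-3.1-70B")  # example path
with open(checkpoint_dir / "model.safetensors.index.json") as f:
    bin_index = json.load(f)
# e.g. {"model.embed_tokens.weight": "model-00001-of-00030.safetensors", ...}
bin_files = {checkpoint_dir / name for name in bin_index["weight_map"].values()}
```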
```diff
@@ -24,63 +25,49 @@ def convert_hf_checkpoint(
 ) -> None:
     if model_name is None:
         model_name = checkpoint_dir.name
-
-    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
-    # need to be copied into model.pth.
-    # Llama 3 70B can't be easily merged into one model.pth file, though, since names of the
-    # weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not
-    # currently supported.
-    # Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
-    is_llama3 = "Llama-3" in model_name
-    if is_llama3:
-        # Check if we have multiple original/consolidated.NN.pth files and report error
-        # if we do for Llama 3.
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
-        if len(bin_files) > 1:
-            raise ValueError(
-                f"Multiple consolidated.NN.pth files found in {original_dir}. "
-                "Merging them into one model.pth file is not supported for Llama 3.")
-
-
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    if not is_llama3:
-        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-        assert model_map_json.is_file()
-
-        with open(model_map_json) as json_map:
-            bin_index = json.load(json_map)
-
-        weight_map = {
-            "model.embed_tokens.weight": "tok_embeddings.weight",
-            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-            "model.norm.weight": "norm.weight",
-            "lm_head.weight": "output.weight",
-        }
-        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
-    else:
-        # There is no separate pytorch_model.bin.index.json file for llama3.
-        # Instead, we will just use all original/consolidated.NN.pth files.
-        # so, we use model.safetensors.index.json
-        weight_map = None
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
-
+    model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
+    model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_map_json = None
+
+    try:
+        assert model_map_json_safetensors.is_file()
+        model_map_json = model_map_json_safetensors
+        print(f"Found safetensors index at {model_map_json_safetensors}")
+    except AssertionError:
+        print(f"{model_map_json_safetensors} not found")
+    if model_map_json is None:
+        try:
+            assert model_map_json_pytorch.is_file()
+            model_map_json = model_map_json_pytorch
+            print(f"Found pytorch index at {model_map_json_pytorch}")
+        except AssertionError:
+            print(f"{model_map_json_pytorch} not found")
+
+    if model_map_json is None: raise Exception("No model map found!")
+
+    with open(model_map_json) as json_map:
+        bin_index = json.load(json_map)
+
+    weight_map = {
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+        "model.norm.weight": "norm.weight",
+        "lm_head.weight": "output.weight",
+    }
+    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_head):
         dim = config.dim
```
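The rewritten lookup above prefers `model.safetensors.index.json` and falls back to `pytorch_model.bin.index.json`, using assert/except as control flow. The same discovery step can be sketched as a plain loop; `find_model_map_json` is an illustrative helper name, not part of the commit:

```python
# Sketch only: equivalent index discovery without assert-driven control flow.
from pathlib import Path

def find_model_map_json(checkpoint_dir: Path) -> Path:  # hypothetical helper
    # Prefer the safetensors index, then fall back to the pytorch one.
    for name in ("model.safetensors.index.json", "pytorch_model.bin.index.json"):
        candidate = checkpoint_dir / name
        if candidate.is_file():
            print(f"Found index at {candidate}")
            return candidate
    raise FileNotFoundError(f"No model map found in {checkpoint_dir}!")
```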
```diff
@@ -92,40 +79,44 @@ def permute(w, n_head):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-        merged_result.update(state_dict)
+        if "safetensors" in str(file):
+            state_dict = load_safetensors_file(str(file), device="cpu")
+            merged_result.update(state_dict)
+        else:
+            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+            merged_result.update(state_dict)
     final_result = {}
-    if weight_map is not None:
-        for key, value in merged_result.items():
-            if "layers" in key:
-                abstract_key = re.sub(r'(\d+)', '{}', key)
-                layer_num = re.search(r'\d+', key).group(0)
-                new_key = weight_map[abstract_key]
-                if new_key is None:
-                    continue
-                new_key = new_key.format(layer_num)
-            else:
-                new_key = weight_map[key]
-
-            final_result[new_key] = value
-
-        for key in tuple(final_result.keys()):
-            if "wq" in key:
-                q = final_result[key]
-                k = final_result[key.replace("wq", "wk")]
-                v = final_result[key.replace("wq", "wv")]
-                q = permute(q, config.n_head)
-                k = permute(k, config.n_local_heads)
-                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-                del final_result[key]
-                del final_result[key.replace("wq", "wk")]
-                del final_result[key.replace("wq", "wv")]
-    else:
-        final_result = merged_result
+    for key, value in merged_result.items():
+        if "layers" in key:
+            abstract_key = re.sub(r'(\d+)', '{}', key)
+            layer_num = re.search(r'\d+', key).group(0)
+            new_key = weight_map[abstract_key]
+            if new_key is None:
+                continue
+            new_key = new_key.format(layer_num)
+        else:
+            new_key = weight_map[key]
+
+        final_result[new_key] = value
+
+    for key in tuple(final_result.keys()):
+        if "wq" in key:
+            q = final_result[key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
+            q = permute(q, config.n_head)
+            k = permute(k, config.n_local_heads)
+            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+            del final_result[key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
-    if is_llama3:
-        original_dir = checkpoint_dir / "original"
+    if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower():
+        if 'llama-3.1-405b' in model_name.lower():
+            original_dir = checkpoint_dir / "original" / "mp16"
+        else:
+            original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
         print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
```

scripts/download.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -15,7 +15,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
     from huggingface_hub import snapshot_download
     os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
     try:
-        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
+        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
     except HTTPError as e:
         if e.response.status_code == 401:
             print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
```

torchao/_models/llama/model.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -72,7 +72,13 @@ def from_name(cls, name: str):
     "stories15M": dict(n_layer=6, n_head=6, dim=288),
     "stories110M": dict(n_layer=12, n_head=12, dim=768),
     "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
-    "Llama-3.1-8B": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, use_scaled_rope=True)
+    "Llama-3.1-8B": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, use_scaled_rope=True),
+    "Llama-3.1-70B": dict(block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000,
+        use_scaled_rope=True
+    ),
+    "Llama-3.1-405B": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
+        use_scaled_rope=True
+    ),
 }
 
 # this is a model specific variable that controls whether index_put is used for the kv_cache update,
```
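Both new entries keep `head_dim = dim // n_head = 128`. As a quick plausibility check on the 70B numbers, a back-of-the-envelope parameter count (a rough sketch that ignores norm weights and assumes untied input and output embeddings) lands near the advertised size:

```python
# Rough parameter estimate from the Llama-3.1-70B config values above.
n_layer, n_head, n_local_heads, dim = 80, 64, 8, 8192
intermediate_size, vocab_size = 28672, 128256
head_dim = dim // n_head                                    # 128
attn = 2 * dim * dim + 2 * dim * n_local_heads * head_dim   # wq, wo + wk, wv
mlp = 3 * dim * intermediate_size                           # w1, w2, w3
total = n_layer * (attn + mlp) + 2 * vocab_size * dim       # + embeddings, output head
print(head_dim, f"~{total / 1e9:.1f}B")                     # 128 ~70.6B
```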
