 import json
 import re
 import shutil
+import sys
 from pathlib import Path
 from typing import Optional
-
+from safetensors.torch import load_file as load_safetensors_file
 import torch
 
 from torchao._models.llama.model import ModelArgs
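The only new dependency on the conversion path is `safetensors`: `load_file` (aliased here to `load_safetensors_file`) returns the same kind of name-to-tensor dictionary that `torch.load` produces for a `.bin` shard, which is what lets the merging loop later in the diff treat both formats uniformly. A minimal sketch of that equivalence, with placeholder shard filenames (not files shipped with this repo):

```python
from safetensors.torch import load_file as load_safetensors_file
import torch

# Placeholder shard names; point these at real files from a downloaded checkpoint.
st_shard = load_safetensors_file("model-00001-of-00004.safetensors", device="cpu")
pt_shard = torch.load(
    "pytorch_model-00001-of-00002.bin",
    map_location="cpu",
    mmap=True,
    weights_only=True,
)
# Both loaders return a plain dict of parameter name -> torch.Tensor, so the
# shard-merging loop below does not care which format a file came from.
assert isinstance(st_shard, dict) and isinstance(pt_shard, dict)
```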
@@ -24,63 +25,49 @@ def convert_hf_checkpoint(
 ) -> None:
     if model_name is None:
         model_name = checkpoint_dir.name
-
-    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
-    # need to be copied into model.pth.
-    # Llama 3 70B can't be easily merged into one model.pth file, though, since names of the
-    # weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not
-    # currently supported.
-    # Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
-    is_llama3 = "Llama-3" in model_name
-    if is_llama3:
-        # Check if we have multiple original/consolidated.NN.pth files and report error
-        # if we do for Llama 3.
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
-        if len(bin_files) > 1:
-            raise ValueError(
-                f"Multiple consolidated.NN.pth files found in {original_dir}. "
-                "Merging them into one model.pth file is not supported for Llama 3.")
-
-
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    if not is_llama3:
-        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-        assert model_map_json.is_file()
-
-        with open(model_map_json) as json_map:
-            bin_index = json.load(json_map)
-
-        weight_map = {
-            "model.embed_tokens.weight": "tok_embeddings.weight",
-            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-            "model.norm.weight": "norm.weight",
-            "lm_head.weight": "output.weight",
-        }
-        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
-    else:
-        # There is no separate pytorch_model.bin.index.json file for llama3.
-        # Instead, we will just use all original/consolidated.NN.pth files.
-        # so, we use model.safetensors.index.json
-        weight_map = None
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
-
+    model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
+    model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_map_json = None
+
+    try:
+        assert model_map_json_safetensors.is_file()
+        model_map_json = model_map_json_safetensors
+        print(f"Found safetensors index at {model_map_json_safetensors}")
+    except AssertionError:
+        print(f"{model_map_json_safetensors} not found")
+    if model_map_json is None:
+        try:
+            assert model_map_json_pytorch.is_file()
+            model_map_json = model_map_json_pytorch
+            print(f"Found pytorch index at {model_map_json_pytorch}")
+        except AssertionError:
+            print(f"{model_map_json_pytorch} not found")
+
+    if model_map_json is None: raise Exception("No model map found!")
+
+    with open(model_map_json) as json_map:
+        bin_index = json.load(json_map)
+
+    weight_map = {
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+        "model.norm.weight": "norm.weight",
+        "lm_head.weight": "output.weight",
+    }
+    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_head):
         dim = config.dim
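This hunk replaces the Llama-3-specific branches with a single index lookup: prefer `model.safetensors.index.json`, fall back to `pytorch_model.bin.index.json`, and raise if neither exists. Both index files share the standard Hugging Face layout, so the set of shard files can be derived the same way in either case. A small self-contained sketch (shard names and checkpoint directory are made up for illustration):

```python
import json
from pathlib import Path

checkpoint_dir = Path("checkpoints/meta-llama/Meta-Llama-3-8B")  # hypothetical path

# What json.load(...) on a Hugging Face index file typically yields: many
# parameter names mapping to a handful of shard files (names invented here).
bin_index = {
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
        "lm_head.weight": "model-00004-of-00004.safetensors",
    }
}

# Same expression as in the diff: a set de-duplicates the shard names so each
# file is loaded exactly once.
bin_files = {checkpoint_dir / shard for shard in bin_index["weight_map"].values()}
print(sorted(str(f) for f in bin_files))  # -> two unique shard paths
```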
@@ -92,40 +79,44 @@ def permute(w, n_head):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-        merged_result.update(state_dict)
+        if "safetensors" in str(file):
+            state_dict = load_safetensors_file(str(file), device="cpu")
+            merged_result.update(state_dict)
+        else:
+            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+            merged_result.update(state_dict)
     final_result = {}
-    if weight_map is not None:
-        for key, value in merged_result.items():
-            if "layers" in key:
-                abstract_key = re.sub(r'(\d+)', '{}', key)
-                layer_num = re.search(r'\d+', key).group(0)
-                new_key = weight_map[abstract_key]
-                if new_key is None:
-                    continue
-                new_key = new_key.format(layer_num)
-            else:
-                new_key = weight_map[key]
-
-            final_result[new_key] = value
-
-        for key in tuple(final_result.keys()):
-            if "wq" in key:
-                q = final_result[key]
-                k = final_result[key.replace("wq", "wk")]
-                v = final_result[key.replace("wq", "wv")]
-                q = permute(q, config.n_head)
-                k = permute(k, config.n_local_heads)
-                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-                del final_result[key]
-                del final_result[key.replace("wq", "wk")]
-                del final_result[key.replace("wq", "wv")]
-    else:
-        final_result = merged_result
+    for key, value in merged_result.items():
+        if "layers" in key:
+            abstract_key = re.sub(r'(\d+)', '{}', key)
+            layer_num = re.search(r'\d+', key).group(0)
+            new_key = weight_map[abstract_key]
+            if new_key is None:
+                continue
+            new_key = new_key.format(layer_num)
+        else:
+            new_key = weight_map[key]
+
+        final_result[new_key] = value
+
+    for key in tuple(final_result.keys()):
+        if "wq" in key:
+            q = final_result[key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
+            q = permute(q, config.n_head)
+            k = permute(k, config.n_local_heads)
+            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+            del final_result[key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
-    if is_llama3:
-        original_dir = checkpoint_dir / "original"
+    if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower():
+        if 'llama-3.1-405b' in model_name.lower():
+            original_dir = checkpoint_dir / "original" / "mp16"
+        else:
+            original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
         print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
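For reference, a hypothetical end-to-end invocation of the updated converter. The import path and checkpoint directory are assumptions rather than part of this diff, and the keyword names are inferred from the function body above:

```python
from pathlib import Path

# Assumed module path; adjust to wherever this script lives in your checkout.
from scripts.convert_hf_checkpoint import convert_hf_checkpoint

checkpoint_dir = Path("checkpoints/meta-llama/Meta-Llama-3.1-8B")  # placeholder
convert_hf_checkpoint(checkpoint_dir=checkpoint_dir, model_name=None)
# On success this writes checkpoint_dir / "model.pth"; for Llama 3 / 3.1 model
# names it also copies original/tokenizer.model (original/mp16/tokenizer.model
# for the 3.1 405B layout) to checkpoint_dir / "tokenizer.model".
```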