Commit 4172ce8

Model llama 3.2 (#734)
* fixed typo
* added llama 3.2-1b
* configured 3b
* configured instruct models
1 parent 0eda78f commit 4172ce8

File tree

2 files changed: +81 -1 lines changed


transformer_lens/HookedEncoder.py

Lines changed: 1 addition & 1 deletion

@@ -255,7 +255,7 @@ def from_pretrained(
         if move_to_device:
             model.to(cfg.device)

-        print(f"Loaded pretrained model {model_name} into HookedTransformer")
+        print(f"Loaded pretrained model {model_name} into HookedEncoder")

         return model

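For context, a minimal sketch of where the corrected message surfaces, assuming a standard TransformerLens install; "bert-base-cased" is just an illustrative encoder checkpoint from the supported-model list below, not something this commit changes:

from transformer_lens import HookedEncoder

# Loading an encoder-only model now reports the correct class name, e.g.
# "Loaded pretrained model bert-base-cased into HookedEncoder"
model = HookedEncoder.from_pretrained("bert-base-cased")
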
transformer_lens/loading_from_pretrained.py

Lines changed: 80 additions & 0 deletions
@@ -151,6 +151,10 @@
     "meta-llama/Meta-Llama-3-8B-Instruct",
     "meta-llama/Meta-Llama-3-70B",
     "meta-llama/Meta-Llama-3-70B-Instruct",
+    "meta-llama/Llama-3.2-1B",
+    "meta-llama/Llama-3.2-3B",
+    "meta-llama/Llama-3.2-1B-Instruct",
+    "meta-llama/Llama-3.2-3B-Instruct",
     "Baidicoot/Othello-GPT-Transformer-Lens",
     "bert-base-cased",
     "roneneldan/TinyStories-1M",

@@ -885,6 +889,82 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "final_rms": True,
             "gated_mlp": True,
         }
+    elif "Llama-3.2-1B" in official_model_name:
+        cfg_dict = {
+            "d_model": 2048,
+            "d_head": 64,
+            "n_heads": 32,
+            "d_mlp": 8192,
+            "n_layers": 16,
+            "n_ctx": 2048,  # capped due to memory issues
+            "eps": 1e-5,
+            "d_vocab": 128256,
+            "act_fn": "silu",
+            "n_key_value_heads": 8,
+            "normalization_type": "RMS",
+            "positional_embedding_type": "rotary",
+            "rotary_adjacent_pairs": False,
+            "rotary_dim": 64,
+            "final_rms": True,
+            "gated_mlp": True,
+        }
+    elif "Llama-3.2-3B" in official_model_name:
+        cfg_dict = {
+            "d_model": 3072,
+            "d_head": 128,
+            "n_heads": 24,
+            "d_mlp": 8192,
+            "n_layers": 28,
+            "n_ctx": 2048,  # capped due to memory issues
+            "eps": 1e-5,
+            "d_vocab": 128256,
+            "act_fn": "silu",
+            "n_key_value_heads": 8,
+            "normalization_type": "RMS",
+            "positional_embedding_type": "rotary",
+            "rotary_adjacent_pairs": False,
+            "rotary_dim": 128,
+            "final_rms": True,
+            "gated_mlp": True,
+        }
+    elif "Llama-3.2-1B-Instruct" in official_model_name:
+        cfg_dict = {
+            "d_model": 2048,
+            "d_head": 64,
+            "n_heads": 32,
+            "d_mlp": 8192,
+            "n_layers": 16,
+            "n_ctx": 2048,  # capped due to memory issues
+            "eps": 1e-5,
+            "d_vocab": 128256,
+            "act_fn": "silu",
+            "n_key_value_heads": 8,
+            "normalization_type": "RMS",
+            "positional_embedding_type": "rotary",
+            "rotary_adjacent_pairs": False,
+            "rotary_dim": 64,
+            "final_rms": True,
+            "gated_mlp": True,
+        }
+    elif "Llama-3.2-3B-Instruct" in official_model_name:
+        cfg_dict = {
+            "d_model": 3072,
+            "d_head": 128,
+            "n_heads": 24,
+            "d_mlp": 8192,
+            "n_layers": 28,
+            "n_ctx": 2048,  # capped due to memory issues
+            "eps": 1e-5,
+            "d_vocab": 128256,
+            "act_fn": "silu",
+            "n_key_value_heads": 8,
+            "normalization_type": "RMS",
+            "positional_embedding_type": "rotary",
+            "rotary_adjacent_pairs": False,
+            "rotary_dim": 128,
+            "final_rms": True,
+            "gated_mlp": True,
+        }
     elif architecture == "GPTNeoForCausalLM":
         cfg_dict = {
             "d_model": hf_config.hidden_size,

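A minimal usage sketch for the newly registered checkpoints, assuming you have accepted Meta's license for the gated Llama 3.2 repositories on Hugging Face and are authenticated locally (e.g. via huggingface-cli login); the model name is taken verbatim from the list added above:

from transformer_lens import HookedTransformer

# "meta-llama/Llama-3.2-1B" resolves through the new Llama-3.2-1B branch of
# convert_hf_model_config: d_model=2048 (= 32 heads x 64 d_head), 16 layers,
# grouped-query attention with 8 KV heads, and a 128256-token vocabulary.
model = HookedTransformer.from_pretrained("meta-llama/Llama-3.2-1B")

logits, cache = model.run_with_cache("TransformerLens now supports Llama 3.2.")
print(logits.shape)  # [batch, position, d_vocab] -> torch.Size([1, n_tokens, 128256])
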
0 commit comments