Skip to content

Commit 567a627

Browse files
Unit tests loading from pretrained fill missing keys (#623)
* Add unit tests for fill_missing_keys * Reformat test_loading_from_pretrained.py with black * Rename unit test file to test_loading_from_pretrained_utilities to avoid naming conflict
1 parent 2ee51c5 commit 567a627

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from unittest import mock
2+
3+
import pytest
4+
5+
from transformer_lens import HookedTransformer
6+
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
7+
from transformer_lens.loading_from_pretrained import fill_missing_keys
8+
9+
10+
def get_default_config():
11+
return HookedTransformerConfig(
12+
d_model=128, d_head=8, n_heads=16, n_ctx=128, n_layers=1, d_vocab=50257, attn_only=True
13+
)
14+
15+
16+
# Successes
17+
18+
19+
@mock.patch("logging.warning")
20+
def test_fill_missing_keys(mock_warning):
21+
cfg = get_default_config()
22+
model = HookedTransformer(cfg)
23+
default_state_dict = model.state_dict()
24+
25+
incomplete_state_dict = {k: v for k, v in default_state_dict.items() if "W_" not in k}
26+
27+
filled_state_dict = fill_missing_keys(model, incomplete_state_dict)
28+
29+
assert set(filled_state_dict.keys()) == set(default_state_dict.keys())
30+
31+
# Check that warnings were issued for missing weight matrices
32+
for key in default_state_dict:
33+
if "W_" in key and key not in incomplete_state_dict:
34+
mock_warning.assert_any_call(
35+
f"Missing key for a weight matrix in pretrained, filled in with an empty tensor: {key}"
36+
)
37+
38+
39+
def test_fill_missing_keys_with_hf_model_keys():
40+
cfg = get_default_config()
41+
model = HookedTransformer(cfg)
42+
default_state_dict = model.state_dict()
43+
44+
incomplete_state_dict = {k: v for k, v in default_state_dict.items() if "hf_model" not in k}
45+
46+
filled_state_dict = fill_missing_keys(model, incomplete_state_dict)
47+
48+
expected_keys = set(default_state_dict.keys()) - {
49+
k for k in default_state_dict.keys() if "hf_model" in k
50+
}
51+
assert set(filled_state_dict.keys()) == expected_keys
52+
53+
54+
def test_fill_missing_keys_no_missing_keys():
55+
cfg = get_default_config()
56+
model = HookedTransformer(cfg)
57+
default_state_dict = model.state_dict()
58+
59+
filled_state_dict = fill_missing_keys(model, default_state_dict)
60+
61+
assert filled_state_dict == default_state_dict
62+
63+
64+
# Failures
65+
66+
67+
def test_fill_missing_keys_raises_error_on_invalid_model():
68+
invalid_model = None
69+
default_state_dict = {}
70+
71+
with pytest.raises(AttributeError):
72+
fill_missing_keys(invalid_model, default_state_dict)

0 commit comments

Comments
 (0)