Commit dd8c1e0

Merge pull request #712 from TransformerLensOrg/dev
v2.4.1
2 parents cb5017a + db1a7f5 commit dd8c1e0

2 files changed: +77 −2 lines changed

tests/unit/test_use_attn_result.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
import torch

from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def test_atten_result_normal_attn_correct():
    """Verifies that the attn_result flag does not change the output for models with normal attention."""
    d_model = 128
    d_head = 8
    n_heads = 16
    n_ctx = 128
    n_layers = 1
    d_vocab = 10

    cfg = HookedTransformerConfig(
        d_model=d_model,
        d_head=d_head,
        n_heads=n_heads,
        n_ctx=n_ctx,
        n_layers=n_layers,
        attn_only=True,
        d_vocab=d_vocab,
    )

    model = HookedTransformer(cfg)
    assert model.cfg.use_split_qkv_input is False

    x = torch.arange(1, 9).unsqueeze(0)
    normal_output = model(x)

    model.set_use_attn_result(True)
    assert model.cfg.use_attn_result is True

    split_output = model(x)

    assert torch.allclose(normal_output, split_output, atol=1e-6)


def test_atten_result_grouped_query_attn_correct():
    """Verifies that the atten_result flag does not change the output for models with grouped query attention."""

    d_model = 128
    d_head = 8
    n_heads = 16
    n_ctx = 128
    n_key_value_heads = 2
    n_layers = 1
    d_vocab = 10

    cfg = HookedTransformerConfig(
        d_model=d_model,
        d_head=d_head,
        n_heads=n_heads,
        n_ctx=n_ctx,
        n_key_value_heads=n_key_value_heads,
        n_layers=n_layers,
        attn_only=True,
        d_vocab=d_vocab,
    )

    model = HookedTransformer(cfg)
    assert model.cfg.use_split_qkv_input is False

    x = torch.arange(1, 9).unsqueeze(0)
    normal_output = model(x)

    model.set_use_attn_result(True)
    assert model.cfg.use_attn_result is True

    split_output = model(x)

    assert torch.allclose(normal_output, split_output, atol=1e-6)
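
For context (not part of the commit): `set_use_attn_result(True)` only exposes a per-head decomposition of the attention output; it is not supposed to perturb the forward pass, which is exactly what these tests assert. Below is a minimal sketch of what the flag buys you, assuming the standard TransformerLens hook names `blocks.0.attn.hook_result` and `blocks.0.hook_attn_out` and reusing the tiny attn-only config from the test above.

import torch

from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Same tiny attn-only config as the test above (illustrative values).
cfg = HookedTransformerConfig(
    d_model=128, d_head=8, n_heads=16, n_ctx=128,
    n_layers=1, attn_only=True, d_vocab=10,
)
model = HookedTransformer(cfg)
model.set_use_attn_result(True)

x = torch.arange(1, 9).unsqueeze(0)
_, cache = model.run_with_cache(x)

# With the flag on, the per-head contribution to the residual stream is cached
# under hook_result, with shape [batch, pos, head_index, d_model].
per_head = cache["blocks.0.attn.hook_result"]

# Summing over heads and adding the output bias recovers the layer's attention
# output, which is why enabling the flag should leave the logits unchanged.
recombined = per_head.sum(dim=2) + model.blocks[0].attn.b_O
assert torch.allclose(recombined, cache["blocks.0.hook_attn_out"], atol=1e-5)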

transformer_lens/HookedTransformer.py

Lines changed: 4 additions & 2 deletions
@@ -8,7 +8,6 @@
 alteration of activations in individual components like attention heads and MLP layers, facilitating
 a deeper understanding of the internal workings of transformers like GPT-2.
 """
-
 import logging
 import os
 from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload
@@ -1570,7 +1569,10 @@ def load_and_process_state_dict(
             # so that quantization settings are not lost
             self.load_state_dict(state_dict, assign=True, strict=False)
         else:
-            self.load_state_dict(state_dict, strict=False)
+            state_dict_keys = list(state_dict.keys())
+            for key in state_dict_keys:
+                self.load_state_dict({key: state_dict[key]}, strict=False)
+                del state_dict[key]
 
     def fill_missing_keys(self, state_dict):
         return loading.fill_missing_keys(self, state_dict)
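
The behaviour here is unchanged; the point of the loop is memory. Loading the processed state dict one key at a time, and deleting each source tensor as soon as it has been copied into the model, keeps roughly one full set of weights in memory instead of two. A minimal standalone sketch of the same pattern, using a hypothetical helper name and a toy nn.Linear in place of the real model:

import torch
from torch import nn

def load_state_dict_incrementally(model: nn.Module, state_dict: dict) -> None:
    # Hypothetical helper mirroring the diff above: load one parameter at a
    # time (strict=False allows a partial dict) and free the source tensor
    # immediately, so peak memory is ~one set of weights plus one tensor.
    for key in list(state_dict.keys()):
        model.load_state_dict({key: state_dict[key]}, strict=False)
        del state_dict[key]

# Toy usage: copy weights from one Linear into another.
src, dst = nn.Linear(4, 4), nn.Linear(4, 4)
sd = {k: v.clone() for k, v in src.state_dict().items()}
load_state_dict_incrementally(dst, sd)
assert not sd  # the source dict is consumed as it is loaded
assert torch.allclose(dst.weight, src.weight)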
