Commit 53dee84

Merge pull request #848 from TransformerLensOrg/dev
Release 2.13.0
2 parents db0f191 + 0c78adb commit 53dee84

File tree

14 files changed: +872 −385 lines changed

.github/workflows/checks.yml

Lines changed: 7 additions & 7 deletions

@@ -70,7 +70,7 @@ jobs:
       - name: Unit Test
         run: make unit-test
         env:
-          HF_TOKEN: ${{ vars.HF_TOKEN }}
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
       - name: Acceptance Test
         run: make acceptance-test
       - name: Build check
@@ -109,11 +109,11 @@ jobs:
       - name: Test Suite with Coverage Report
         run: make coverage-report-test
         env:
-          HF_TOKEN: ${{ vars.HF_TOKEN }}
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
       - name: Build check
         run: poetry build
       - name: Upload Coverage Report Artifact
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: test-coverage
           path: htmlcov
@@ -192,16 +192,16 @@ jobs:
       - name: Install dependencies
         run: poetry install --with docs
       - name: Download Test Coverage Artifact
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
         with:
           name: test-coverage
           path: docs/source/_static/coverage
       - name: Build Docs
         run: poetry run build-docs
         env:
-          HF_TOKEN: ${{ vars.HF_TOKEN }}
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
       - name: Upload Docs Artifact
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: documentation
           path: docs/build
@@ -215,7 +215,7 @@ jobs:
     steps:
      - uses: actions/checkout@v4
      - name: Download Docs Artifact
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
       with:
         name: documentation
         path: docs/build
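
Context for the token change: vars.HF_TOKEN is a plain repository variable, while secrets.HF_TOKEN is masked and only exposed to trusted workflow runs. Test suites that depend on the token can also guard themselves when it is absent; a minimal pytest sketch (hypothetical helper, not part of this diff):

import os

import pytest

# Skip tests that download gated Hugging Face models when no token is
# configured, e.g. on forks where the HF_TOKEN secret is unavailable.
requires_hf_token = pytest.mark.skipif(
    not os.environ.get("HF_TOKEN"),
    reason="HF_TOKEN is not set; skipping gated-model tests.",
)


@requires_hf_token
def test_can_download_gated_model():
    assert os.environ["HF_TOKEN"]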

demos/LLaVA.ipynb

Lines changed: 449 additions & 0 deletions
Large diffs are not rendered by default.

poetry.lock

Lines changed: 65 additions & 242 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 3 additions & 8 deletions

@@ -29,18 +29,13 @@
 python=">=3.8,<4.0"
 rich=">=12.6.0"
 sentencepiece="*"
-torch=[
-    {platform="!=linux", version=">=1.10,!=2.0,!=2.1.0"}, # Pin >=2.1.1 on Apple devices due to known MPS errors on 2.1.0
-    {platform="linux", version=">=1.10"}, # We can use any torch version on Linux (e.g colab)
-]
+torch=">=2.2,<2.5"
 tqdm=">=4.64.1"
-transformers=[
-    {version=">=4.37", python=">=3.8,<3.9"},
-    {version=">=4.41,<4.42", python=">=3.9,<4"},
-]
+transformers=">=4.43"
 typing-extensions="*"
 wandb=">=0.13.5"
 typeguard = "^4.2"
+transformers-stream-generator = "^0.0.5"

 [tool.poetry.group]
 [tool.poetry.group.dev.dependencies]
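
The per-platform torch matrix and the per-Python transformers pins collapse into two plain version ranges, and transformers-stream-generator arrives as a new dependency. A quick sanity check of an existing environment against the new pins (illustrative only, not part of this commit; assumes the packaging library is installed):

from importlib.metadata import version

from packaging.specifiers import SpecifierSet

# Verify the installed versions satisfy the constraints introduced above.
for pkg, spec in {"torch": ">=2.2,<2.5", "transformers": ">=4.43"}.items():
    installed = version(pkg)
    print(f"{pkg} {installed} satisfies {spec}: {installed in SpecifierSet(spec)}")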

tests/acceptance/test_hooked_encoder_decoder.py

Lines changed: 34 additions & 11 deletions

@@ -122,7 +122,8 @@ def test_relative_attention_bias(our_model, huggingface_model, hello_world_token

     embed_out = huggingface_embed(hello_world_tokens)

-    huggingface_attn_out = huggingface_attn(embed_out)[0]
+    cache_position = torch.arange(input_len)
+    huggingface_attn_out = huggingface_attn(embed_out, cache_position=cache_position)[0]
     our_attn_out = our_attn(embed_out, embed_out, embed_out, position_bias=our_bias)

     assert_close(our_attn_out, huggingface_attn_out, rtol=7.4e-4, atol=1e-5)
@@ -139,7 +140,8 @@ def test_relative_attention_layer(our_model, huggingface_model, hello_world_toke
     resid_norm = our_block.ln1(resid)
     our_out = resid + our_block.attn(resid_norm, resid_norm, resid_norm, position_bias=our_bias)

-    hf_out = hf_block(resid)[0]
+    cache_position = torch.arange(input_len)
+    hf_out = hf_block(resid, cache_position=cache_position)[0]
     assert_close(our_out, hf_out, rtol=1.3e-6, atol=4e-5)

@@ -151,7 +153,10 @@ def test_attention(our_model, huggingface_model, hello_world_tokens):
     our_attn = our_model.encoder[1].attn

     our_attn_out = our_attn(embed_out, embed_out, embed_out)
-    huggingface_attn_out = huggingface_attn(embed_out)[0]
+
+    input_len = hello_world_tokens.shape[1]
+    cache_position = torch.arange(input_len)
+    huggingface_attn_out = huggingface_attn(embed_out, cache_position=cache_position)[0]

     assert_close(our_attn_out, huggingface_attn_out, rtol=5e-4, atol=1e-5)

@@ -164,7 +169,10 @@ def test_decoder_attention(our_model, huggingface_model, hello_world_tokens):
     our_attn = our_model.decoder[1].attn

     our_attn_out = our_attn(embed_out, embed_out, embed_out)
-    huggingface_attn_out = huggingface_attn(embed_out)[0]
+
+    input_len = hello_world_tokens.shape[1]
+    cache_position = torch.arange(input_len)
+    huggingface_attn_out = huggingface_attn(embed_out, cache_position=cache_position)[0]
     assert_close(our_attn_out, huggingface_attn_out, rtol=3e-4, atol=1e-5)

@@ -177,7 +185,9 @@ def test_attention_layer(our_model, huggingface_model, hello_world_tokens):
     norm_embed = our_model.encoder[1].ln1(embed_out)
     our_attn_out = our_attn(norm_embed, norm_embed, norm_embed) + embed_out

-    huggingface_attn_out = huggingface_attn(embed_out)[0]
+    input_len = hello_world_tokens.shape[1]
+    cache_position = torch.arange(input_len)
+    huggingface_attn_out = huggingface_attn(embed_out, cache_position=cache_position)[0]
     assert_close(our_attn_out, huggingface_attn_out, rtol=2e-4, atol=1e-5)

@@ -190,7 +200,9 @@ def test_decoder_attention_layer(our_model, huggingface_model, hello_world_token
     norm_embed = our_model.decoder[1].ln1(embed_out)
     our_attn_out = our_attn(norm_embed, norm_embed, norm_embed) + embed_out

-    huggingface_attn_out = huggingface_attn(embed_out)[0]
+    input_len = hello_world_tokens.shape[1]
+    cache_position = torch.arange(input_len)
+    huggingface_attn_out = huggingface_attn(embed_out, cache_position=cache_position)[0]
     assert_close(our_attn_out, huggingface_attn_out, rtol=3e-4, atol=4e-5)

@@ -203,7 +215,7 @@ def test_cross_attention(our_model, huggingface_model, hello_world_tokens, decod

     our_cross_attn_out = our_cross_attn(decoder_hidden, encoder_hidden, encoder_hidden)
     huggingface_cross_attn_out = huggingface_cross_attn(
-        decoder_hidden, key_value_states=encoder_hidden
+        decoder_hidden, key_value_states=encoder_hidden, cache_position=encoder_hidden
     )[0]
     assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-4, atol=1e-5)

@@ -221,7 +233,9 @@ def test_cross_attention_layer(our_model, huggingface_model, hello_world_tokens,
         our_layer.cross_attn(our_layer.ln2(decoder_hidden), encoder_hidden, encoder_hidden)
         + decoder_hidden
     )
-    huggingface_cross_attn_out = hf_layer(decoder_hidden, key_value_states=encoder_hidden)[0]
+    huggingface_cross_attn_out = hf_layer(
+        decoder_hidden, key_value_states=encoder_hidden, cache_position=encoder_hidden
+    )[0]
     assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-4, atol=1e-5)

@@ -232,7 +246,9 @@ def test_encoder_block(our_model, huggingface_model, hello_world_tokens):

     embed_out = huggingface_embed(hello_world_tokens)

-    hf_out = huggingface_block(embed_out)[0]
+    input_len = hello_world_tokens.shape[1]
+    cache_position = torch.arange(input_len)
+    hf_out = huggingface_block(embed_out, cache_position=cache_position)[0]
     our_out = our_block(embed_out)

     assert_close(our_out, hf_out, rtol=2e-4, atol=2e-5)
@@ -244,10 +260,17 @@ def test_decoder_block(our_model, huggingface_model, hello_world_tokens, decoder
     our_block = our_model.decoder[1]

     encoder_hidden = huggingface_model.encoder(hello_world_tokens)[0]
-    decoder_hidden = huggingface_model.decoder.block[0](huggingface_embed(decoder_input_ids))[0]
+
+    input_len = decoder_input_ids.shape[1]
+    cache_position = torch.arange(input_len)
+    decoder_hidden = huggingface_model.decoder.block[0](
+        huggingface_embed(decoder_input_ids), cache_position=cache_position
+    )[0]

     our_out = our_block(decoder_hidden, encoder_hidden_states=encoder_hidden)
-    hf_out = huggingface_block(decoder_hidden, encoder_hidden_states=encoder_hidden)[0]
+    hf_out = huggingface_block(
+        decoder_hidden, encoder_hidden_states=encoder_hidden, cache_position=encoder_hidden
+    )[0]

     assert_close(hf_out, our_out, rtol=2e-4, atol=2e-5)
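
All of these edits follow one pattern: newer transformers releases expect T5 attention and block forwards to receive a cache_position tensor with one index per input position. A standalone sketch of the call pattern the tests now use (illustrative only; assumes a recent transformers release, and t5-small is just an example checkpoint):

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Drive one HF T5 encoder block directly, supplying the cache_position
# argument that newer transformers versions expect.
hf_model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

tokens = tokenizer("hello world", return_tensors="pt").input_ids
embed_out = hf_model.encoder.embed_tokens(tokens)

cache_position = torch.arange(tokens.shape[1])  # one index per position
block_out = hf_model.encoder.block[0](embed_out, cache_position=cache_position)[0]
print(block_out.shape)  # (batch, seq_len, d_model)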

tests/acceptance/test_hooked_transformer.py

Lines changed: 74 additions & 0 deletions

@@ -553,3 +553,77 @@ def test_all_pythia_models_exist():
             f"Could not download model '{model}' from Huggingface."
             " Maybe the name was changed or the model has been removed."
         )
+
+
+@pytest.mark.parametrize(
+    "input_type,return_type",
+    [
+        ("str", "input"),
+        ("str", "str"),
+        ("str", "tokens"),
+        ("str", "embeds"),
+        ("tokens", "input"),
+        ("tokens", "str"),
+        ("tokens", "tokens"),
+        ("tokens", "embeds"),
+        ("embeds", "input"),
+        ("embeds", "str"),
+        ("embeds", "tokens"),
+        ("embeds", "embeds"),
+    ],
+)
+def test_different_inputs_for_generation(
+    input_type, return_type, print_output=False, max_new_tokens=3
+):
+    from typing import List
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    hooked_llm = HookedTransformer.from_pretrained("gpt2", device=device)
+
+    hooked_llm.eval()
+    for text_input in [
+        "What is the meaning of life?",
+        ["AI will destroy world", "AI will save us"],
+    ]:
+        is_batched = False if isinstance(text_input, str) else True
+
+        tokens_input = hooked_llm.to_tokens(text_input)
+        embeddings_input = hooked_llm.embed(tokens_input)
+
+        if input_type == "str":
+            model_input = text_input
+        elif input_type == "tokens":
+            model_input = tokens_input
+        elif input_type == "embeds":
+            model_input = embeddings_input
+        else:
+            raise ValueError(f"Unknown input_type: {input_type}")
+
+        output = hooked_llm.generate(
+            input=model_input, max_new_tokens=max_new_tokens, return_type=return_type, verbose=False
+        )
+
+        if return_type == "str" or (return_type == "input" and input_type == "str"):
+            if is_batched:
+                assert isinstance(output, List), f"Expected list output but got {type(output)}"
+                assert isinstance(
+                    output[0], str
+                ), f"Expected list of strings but got list of {type(output[0])}"
+            else:
+                assert isinstance(output, str), f"Expected string output but got {type(output)}"
+        elif return_type == "tokens" or (return_type == "input" and input_type == "tokens"):
+            assert isinstance(
+                output, torch.Tensor
+            ), f"Expected tensor output but got {type(output)}"
+            assert output.ndim == 2, f"Expected 2D tensor but got {output.ndim}D"
+        elif return_type == "embeds" or (return_type == "input" and input_type == "embeds"):
+            assert isinstance(
+                output, torch.Tensor
+            ), f"Expected tensor output but got {type(output)}"
+            assert output.ndim == 3, f"Expected 3D tensor but got {output.ndim}D"
+
+        if print_output:
+            print(f"Input type: {input_type}, return type: {return_type}, output:\n{output}")
+
+    if print_output:
+        print()
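
The new test pins down the contract of HookedTransformer.generate: it accepts strings, token tensors, or embedding tensors, and return_type selects the output form ("input" mirrors whatever form was passed in). A minimal usage sketch along the same lines (illustrative, mirroring the test above):

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
model.eval()

# Generate from an embedding input and ask for token ids back.
tokens = model.to_tokens("What is the meaning of life?")
embeds = model.embed(tokens)
out = model.generate(input=embeds, max_new_tokens=3, return_type="tokens", verbose=False)
print(out.shape)  # 2D tensor of token ids, matching the test's expectation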

tests/integration/test_match_huggingface.py

Lines changed: 3 additions & 1 deletion

@@ -40,6 +40,8 @@ def test_compare_huggingface_attention_match_local_implementation(self, model_na
         past_kv_cache_entry=None,
         attention_mask=None,
     )
-    hf_out, _ = hf_model.transformer.h[layer_n].attn(hidden_states=input)
+    hf_out, _, _ = hf_model.transformer.h[layer_n].attn(
+        hidden_states=input, output_attentions=True
+    )

     assert torch.sum(tl_out == hf_out) == math.prod(tl_out.shape)
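
The unpacking changes because, with output_attentions=True, recent transformers versions return three values from GPT-2's attention module, the first being the attention output. The same call in isolation (a sketch; random hidden states stand in for real activations):

import torch
from transformers import GPT2LMHeadModel

hf_model = GPT2LMHeadModel.from_pretrained("gpt2")
hidden = torch.randn(1, 5, hf_model.config.n_embd)  # stand-in activations

# Only the first element (the attention output) is compared in the test.
hf_out, _, _ = hf_model.transformer.h[0].attn(
    hidden_states=hidden, output_attentions=True
)
print(hf_out.shape)  # (1, 5, n_embd)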

transformer_lens/HookedEncoder.py

Lines changed: 2 additions & 2 deletions

@@ -53,10 +53,10 @@ def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
         if tokenizer is not None:
             self.tokenizer = tokenizer
         elif self.cfg.tokenizer_name is not None:
-            huggingface_token = os.environ.get("HF_TOKEN", None)
+            huggingface_token = os.environ.get("HF_TOKEN", "")
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.cfg.tokenizer_name,
-                token=huggingface_token,
+                token=huggingface_token if len(huggingface_token) > 0 else None,
             )
         else:
             self.tokenizer = None

transformer_lens/HookedEncoderDecoder.py

Lines changed: 2 additions & 2 deletions

@@ -57,10 +57,10 @@ def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
         if tokenizer is not None:
             self.tokenizer = tokenizer
         elif self.cfg.tokenizer_name is not None:
-            huggingface_token = os.environ.get("HF_TOKEN", None)
+            huggingface_token = os.environ.get("HF_TOKEN", "")
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.cfg.tokenizer_name,
-                token=huggingface_token,
+                token=huggingface_token if len(huggingface_token) > 0 else None,
            )
         else:
             self.tokenizer = None
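
Both files get the same fix: an empty HF_TOKEN environment variable is now normalized to None rather than forwarded as a literal empty-string credential. The pattern in isolation (bert-base-uncased is just an example tokenizer name):

import os

from transformers import AutoTokenizer

# Treat unset and empty HF_TOKEN the same: pass token=None so the tokenizer
# download falls back to anonymous access instead of an empty credential.
huggingface_token = os.environ.get("HF_TOKEN", "")
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    token=huggingface_token if len(huggingface_token) > 0 else None,
)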
