Commit dc19c08

Merge pull request #780 from TransformerLensOrg/dev
Release 2.9
2 parents 8f482fc + d9792a9 commit dc19c08

File tree

5 files changed: +88 −27 lines changed

tests/acceptance/test_hooked_transformer.py

Lines changed: 21 additions & 1 deletion

@@ -66,7 +66,7 @@
     "redwood_attn_2l": 10.530948638916016,
     "solu-1l": 5.256411552429199,
     "tiny-stories-33M": 12.203617095947266,
-    "bloom-560m": 4.1953,
+    "bloom-560m": 5.237126350402832,
 }
 
 no_processing = [
@@ -175,6 +175,26 @@ def test_from_pretrained_revision():
         raise AssertionError("Should have raised an error")
 
 
+def test_bloom_similarity_with_hf_model_with_kv_cache_activated():
+    tf_model = HookedTransformer.from_pretrained(
+        "bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
+    )
+    hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
+    hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+
+    output_tf = tf_model.generate(
+        text, do_sample=False, use_past_kv_cache=True, verbose=False, max_new_tokens=10
+    )
+    output_hf_tokens = hf_model.generate(
+        hf_tokenizer(text, return_tensors="pt").input_ids,
+        do_sample=False,
+        max_new_tokens=10,
+    )
+    output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True)
+
+    assert output_tf == output_hf_str
+
+
 def check_norm_folding(
     model_name,
     hf_model=None,
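
Note: the test relies on a module-level `text` prompt and the HookedTransformer / AutoModelForCausalLM / AutoTokenizer imports already present in the test file. A minimal standalone sketch of the same equivalence check, with a local prompt standing in for `text` (running it downloads bloom-560m):

# Standalone sketch, assuming transformer_lens and transformers are installed;
# `prompt` is a stand-in for the test module's `text` variable.
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformer_lens import HookedTransformer

prompt = "The quick brown fox"

tl_model = HookedTransformer.from_pretrained(
    "bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
)
hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

# Greedy decoding on both sides so the two outputs are directly comparable.
tl_output = tl_model.generate(
    prompt, do_sample=False, use_past_kv_cache=True, verbose=False, max_new_tokens=10
)
hf_tokens = hf_model.generate(
    hf_tokenizer(prompt, return_tensors="pt").input_ids,
    do_sample=False,
    max_new_tokens=10,
)
hf_output = hf_tokenizer.decode(hf_tokens[0], skip_special_tokens=True)

assert tl_output == hf_output

Because decoding is greedy, any divergence indicates a real mismatch between the two Bloom implementations rather than sampling noise; the ALiBi KV-cache slicing fixed in abstract_attention.py below is one such mismatch this test would catch.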

tests/integration/test_kv_cache.py

Lines changed: 22 additions & 0 deletions

@@ -213,6 +213,28 @@ def test_freeze_cache(pretrained):
     assert not t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=atol)
 
 
+def test_kv_cache_with_custom_attention_mask(pretrained):
+    model, atol = pretrained
+    prompt_pre = "An apple"
+    prompt_post = " a day keeps junk the"
+    prompt_whole = "An apple a day keeps the"
+    tokens_pre = model.to_tokens(prompt_pre)
+    tokens_post = model.to_tokens(prompt_post, prepend_bos=False)
+    tokens_whole = model.to_tokens(prompt_whole)
+    correct_logits = model(tokens_whole)
+
+    past_kv_cache = HookedTransformerKeyValueCache.init_cache(
+        model.cfg, model.cfg.device, tokens_pre.shape[0]
+    )
+    model(tokens_pre, past_kv_cache=past_kv_cache)
+    exp_logits = model(
+        tokens_post,
+        attention_mask=t.tensor([[1, 1, 1, 0, 1]], device=model.cfg.device),
+        past_kv_cache=past_kv_cache,
+    )
+    assert t.allclose(correct_logits[:, -1], exp_logits[:, -1], atol=atol)
+
+
 def test_kv_cache_and_start_at_layer(pretrained):
     model, atol = pretrained
     pre_prompt = "I went to Staten Island,"

transformer_lens/HookedTransformer.py

Lines changed: 26 additions & 19 deletions

@@ -8,6 +8,7 @@
 alteration of activations in individual components like attention heads and MLP layers, facilitating
 a deeper understanding of the internal workings of transformers like GPT-2.
 """
+
 import logging
 import os
 from typing import (
@@ -297,23 +298,25 @@ def input_to_embed(
         if tokens.device.type != self.cfg.device:
            tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))
 
-        if attention_mask is not None:
+        if (
+            (self.tokenizer and self.tokenizer.padding_side == "left")
+            or attention_mask is not None
+            or past_kv_cache is not None
+        ):
+            # This means we need to have an explicit attention mask.
+            if attention_mask is None:
+                # If the padding side is left or we are using caching, we need to compute the attention
+                # mask for the adjustment of absolute positional embeddings and attention masking so
+                # that pad tokens are not attended.
+                if prepend_bos is USE_DEFAULT_VALUE:
+                    prepend_bos = self.cfg.default_prepend_bos
+                attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
+
             assert attention_mask.shape == tokens.shape, (
                 f"Attention mask shape {attention_mask.shape} does not match tokens shape "
                 f"{tokens.shape}"
             )
             attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg))
-        elif (
-            self.tokenizer and self.tokenizer.padding_side == "left"
-        ) or past_kv_cache is not None:
-            # If the padding side is left or we are using caching, we need to compute the attention
-            # mask for the adjustment of absolute positional embeddings and attention masking so
-            # that pad tokens are not attended.
-
-            if prepend_bos is USE_DEFAULT_VALUE:
-                prepend_bos = self.cfg.default_prepend_bos
-            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
-
         if past_kv_cache is not None:
             # past_kv_cache is not None, so we're doing caching.
             # We need to extend the previous attention_mask.
@@ -1080,7 +1083,7 @@ def from_pretrained(
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         move_to_device: bool = True,
         fold_value_biases: bool = True,
-        default_prepend_bos: bool = True,
+        default_prepend_bos: Optional[bool] = None,
         default_padding_side: Literal["left", "right"] = "right",
         dtype="float32",
         first_n_layers: Optional[int] = None,
@@ -1202,11 +1205,15 @@ def from_pretrained(
                 remains exactly the same, and so is just broadcast across the destination positions.
             default_prepend_bos: Default behavior of whether to prepend the BOS
                 token when the methods of HookedTransformer process input text to tokenize (only
-                when input is a string). Defaults to True - even for models not explicitly trained
-                with this, heads often use the first position as a resting position and accordingly
-                lose information from the first token, so this empirically seems to give better
-                results. To change the default behavior to False, pass in default_prepend_bos=False.
-                Note that you can also locally override the default behavior by passing in
+                when input is a string).
+                Resolution order for default_prepend_bos:
+                    1. If user passes value explicitly, use that value
+                    2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
+                    3. Global default (True)
+
+                Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position
+                and accordingly lose information from the first token, so this empirically seems to give better
+                results. Note that you can also locally override the default behavior by passing in
                 prepend_bos=True/False when you call a method that processes the input string.
             from_pretrained_kwargs: Any other optional argument passed to
                 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to
@@ -1350,7 +1357,7 @@ def from_pretrained_no_processing(
         refactor_factored_attn_matrices=False,
         fold_value_biases=False,
         dtype=torch.float32,
-        default_prepend_bos=True,
+        default_prepend_bos=None,
         default_padding_side="right",
         **from_pretrained_kwargs,
     ):
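
The key behavioral change in input_to_embed: a caller-supplied attention_mask and a computed one are no longer handled in separate if/elif arms, so an explicit mask can now be combined with past_kv_cache (which the new integration test above relies on). A condensed paraphrase of the new control flow, not the library method itself; the function name and the compute_mask callback are illustrative only:

from typing import Callable, Optional

import torch


def resolve_attention_mask(
    tokens: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    padding_side_is_left: bool,
    using_kv_cache: bool,
    compute_mask: Callable[[torch.Tensor], torch.Tensor],  # stand-in for utils.get_attention_mask
) -> Optional[torch.Tensor]:
    # One branch covers all three triggers for needing an explicit mask.
    if padding_side_is_left or attention_mask is not None or using_kv_cache:
        if attention_mask is None:
            # Left padding or caching without a user mask: compute one so that
            # pad tokens are not attended and positional embeddings line up.
            attention_mask = compute_mask(tokens)
        assert attention_mask.shape == tokens.shape
        return attention_mask
    # No padding, no cache, no user mask: downstream code can skip masking.
    return None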

transformer_lens/components/abstract_attention.py

Lines changed: 2 additions & 1 deletion

@@ -229,8 +229,9 @@ def forward(
                     self.cfg.n_heads, key_ctx, self.cfg.device
                 )
 
+            # Take the last query_ctx positions so it also works with past_kv_cache
             attn_scores += self.alibi[
-                :, :query_ctx, :key_ctx
+                :, -query_ctx:, :key_ctx
             ]  # [batch, head_index, query_pos, key_pos]
         elif self.cfg.positional_embedding_type == "relative_positional_bias":
             if position_bias is None:
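
Why the slice changes from :query_ctx to -query_ctx:: the ALiBi bias tensor is built with key_ctx rows, and during cached decoding query_ctx shrinks to the number of new tokens (typically 1), so the bias rows that matter are the last ones, not the first. A toy illustration with made-up bias values; only the indexing pattern mirrors the change:

import torch

n_heads, key_ctx, query_ctx = 2, 5, 1  # incremental decoding: 1 new query, 5 keys in total

# Fake bias tensor indexed [head_index, query_pos, key_pos]; real ALiBi values are
# per-head linear penalties, but any position-dependent values make the point.
alibi = torch.arange(n_heads * key_ctx * key_ctx, dtype=torch.float32).reshape(
    n_heads, key_ctx, key_ctx
)

old_slice = alibi[:, :query_ctx, :key_ctx]   # first row: bias as if decoding position 0
new_slice = alibi[:, -query_ctx:, :key_ctx]  # last row: bias for the position actually being decoded

print(old_slice.shape, new_slice.shape)  # both torch.Size([2, 1, 5])
assert not torch.equal(old_slice, new_slice)

Without a cache, query_ctx equals key_ctx and the two slices select the same rows, so ordinary forward passes are unaffected; the difference only shows up when generating with use_past_kv_cache=True.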

transformer_lens/loading_from_pretrained.py

Lines changed: 17 additions & 6 deletions

@@ -1498,7 +1498,7 @@ def get_pretrained_model_config(
     fold_ln: bool = False,
     device: Optional[Union[str, torch.device]] = None,
     n_devices: int = 1,
-    default_prepend_bos: bool = True,
+    default_prepend_bos: Optional[bool] = None,
     dtype: torch.dtype = torch.float32,
     first_n_layers: Optional[int] = None,
     **kwargs,
@@ -1529,11 +1529,15 @@ def get_pretrained_model_config(
         n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
         default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
             methods of HookedTransformer process input text to tokenize (only when input is a string).
-            Defaults to True - even for models not explicitly trained with this, heads often use the
+            Resolution order for default_prepend_bos:
+                1. If user passes value explicitly, use that value
+                2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
+                3. Global default (True)
+
+            Even for models not explicitly trained with the BOS token, heads often use the
             first position as a resting position and accordingly lose information from the first token,
-            so this empirically seems to give better results. To change the default behavior to False, pass in
-            default_prepend_bos=False. Note that you can also locally override the default behavior by passing
-            in prepend_bos=True/False when you call a method that processes the input string.
+            so this empirically seems to give better results. Note that you can also locally override the default behavior
+            by passing in prepend_bos=True/False when you call a method that processes the input string.
         dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
         kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
             Also given to other HuggingFace functions when compatible.
@@ -1610,7 +1614,14 @@ def get_pretrained_model_config(
 
     cfg_dict["device"] = device
     cfg_dict["n_devices"] = n_devices
-    cfg_dict["default_prepend_bos"] = default_prepend_bos
+
+    if default_prepend_bos is not None:
+        # User explicitly set prepend_bos behavior, override config/default value
+        cfg_dict["default_prepend_bos"] = default_prepend_bos
+    elif "default_prepend_bos" not in cfg_dict:
+        # No config value or user override, set default value (True)
+        cfg_dict["default_prepend_bos"] = True
+
     if hf_cfg is not None:
         cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
     if first_n_layers is not None:
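
The docstrings above and this config code now describe the same three-step resolution for default_prepend_bos. A standalone restatement of that logic applied to a plain dict; `resolve_default_prepend_bos` is an illustrative name, not a library function:

from typing import Optional


def resolve_default_prepend_bos(cfg_dict: dict, default_prepend_bos: Optional[bool] = None) -> dict:
    if default_prepend_bos is not None:
        # 1. User explicitly set prepend_bos behavior: override everything.
        cfg_dict["default_prepend_bos"] = default_prepend_bos
    elif "default_prepend_bos" not in cfg_dict:
        # 3. No user override and no model-specific value: fall back to True.
        cfg_dict["default_prepend_bos"] = True
    # 2. Otherwise the model-specific value already in cfg_dict (e.g. False for
    #    bloom models) is kept as-is.
    return cfg_dict


print(resolve_default_prepend_bos({}))                                    # {'default_prepend_bos': True}
print(resolve_default_prepend_bos({"default_prepend_bos": False}))        # model default kept: False
print(resolve_default_prepend_bos({"default_prepend_bos": False}, True))  # user override wins: True

This is what lets bloom models default to prepend_bos=False via their model-specific config entry, while every other model keeps the previous default of True unless the caller overrides it.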
