Commit d9792a9

Authored by degenfabian (Fabian Degen), with bryce13950 (Bryce Meyer)
Fix Bloom-family models producing incorrect outputs when use_past_kv_cache is set to True (#777)
* Fix kv_cache leading to wrong output when used with Bloom models
* Add test for Bloom models when use_past_kv_cache is set to True
* Fix max_length for Hugging Face model in kv_cache test
* Set max_length to 13 for Hugging Face model in kv_cache test
* Use max_new_tokens for Hugging Face model instead of max_length in kv_cache test
* Fix format

Co-authored-by: Bryce Meyer <[email protected]>
Co-authored-by: Fabian Degen <[email protected]>
1 parent 32b87c6 commit d9792a9

File tree

2 files changed: +22 −1 lines changed

tests/acceptance/test_hooked_transformer.py

Lines changed: 20 additions & 0 deletions
@@ -175,6 +175,26 @@ def test_from_pretrained_revision():
         raise AssertionError("Should have raised an error")


+def test_bloom_similarity_with_hf_model_with_kv_cache_activated():
+    tf_model = HookedTransformer.from_pretrained(
+        "bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
+    )
+    hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
+    hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+
+    output_tf = tf_model.generate(
+        text, do_sample=False, use_past_kv_cache=True, verbose=False, max_new_tokens=10
+    )
+    output_hf_tokens = hf_model.generate(
+        hf_tokenizer(text, return_tensors="pt").input_ids,
+        do_sample=False,
+        max_new_tokens=10,
+    )
+    output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True)
+
+    assert output_tf == output_hf_str
+
+
 def check_norm_folding(
     model_name,
     hf_model=None,
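
The new test relies on names defined elsewhere in the test module (the transformers imports and the `text` prompt). A hypothetical minimal setup for running the comparison standalone might look like the sketch below; the prompt value is a placeholder, not the one the test module actually uses.

from transformers import AutoModelForCausalLM, AutoTokenizer

from transformer_lens import HookedTransformer

# Placeholder prompt; the real test module defines its own `text` value.
text = "Hello, my name is"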

transformer_lens/components/abstract_attention.py

Lines changed: 2 additions & 1 deletion
@@ -229,8 +229,9 @@ def forward(
                 self.cfg.n_heads, key_ctx, self.cfg.device
             )

+            # Take the last query_ctx positions so it also works with past_kv_cache
             attn_scores += self.alibi[
-                :, :query_ctx, :key_ctx
+                :, -query_ctx:, :key_ctx
             ] # [batch, head_index, query_pos, key_pos]
         elif self.cfg.positional_embedding_type == "relative_positional_bias":
             if position_bias is None:
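
Why the slice direction matters: with use_past_kv_cache, only the newest token is passed as a query, so query_ctx is 1 while key_ctx covers all cached positions. Slicing the ALiBi bias with :query_ctx grabs the row for position 0 (no distance penalty at all), while -query_ctx: grabs the row for the actual latest position. Below is a minimal PyTorch sketch of that difference, using a simplified stand-in for the ALiBi tensor; the slope values are illustrative and this is not the library's create_alibi_bias implementation.

import torch

# Toy ALiBi bias: bias[head, q, k] = -slope * (q - k) for k <= q, 0 otherwise.
n_heads, key_ctx = 2, 5
slopes = torch.tensor([0.5, 0.25])                    # one slope per head (illustrative values)
q_idx = torch.arange(key_ctx).view(1, -1, 1)          # query positions
k_idx = torch.arange(key_ctx).view(1, 1, -1)          # key positions
alibi = -slopes.view(-1, 1, 1) * (q_idx - k_idx).clamp(min=0)  # [head, q_pos, k_pos]

# Incremental decoding with a KV cache: one new query, key_ctx cached keys.
query_ctx = 1
wrong = alibi[:, :query_ctx, :key_ctx]   # first row: bias as if the query sat at position 0
right = alibi[:, -query_ctx:, :key_ctx]  # last row: bias for the actual newest position

print(wrong[0, 0])  # all zeros: every key looks equally close to the query
print(right[0, 0])  # [-2.0, -1.5, -1.0, -0.5, 0.0]: the expected distance penalties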
