 8 |  8 |  alteration of activations in individual components like attention heads and MLP layers, facilitating
 9 |  9 |  a deeper understanding of the internal workings of transformers like GPT-2.
10 | 10 |  """
   | 11 | +
11 | 12 |  import logging
12 | 13 |  import os
13 | 14 |  from typing import (
@@ -297,23 +298,25 @@ def input_to_embed(
297 | 298 |          if tokens.device.type != self.cfg.device:
298 | 299 |              tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))
299 | 300 |  
300 |     | -        if attention_mask is not None:
    | 301 | +        if (
    | 302 | +            (self.tokenizer and self.tokenizer.padding_side == "left")
    | 303 | +            or attention_mask is not None
    | 304 | +            or past_kv_cache is not None
    | 305 | +        ):
    | 306 | +            # This means we need to have an explicit attention mask.
    | 307 | +            if attention_mask is None:
    | 308 | +                # If the padding side is left or we are using caching, we need to compute the attention
    | 309 | +                # mask for the adjustment of absolute positional embeddings and attention masking so
    | 310 | +                # that pad tokens are not attended.
    | 311 | +                if prepend_bos is USE_DEFAULT_VALUE:
    | 312 | +                    prepend_bos = self.cfg.default_prepend_bos
    | 313 | +                attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
    | 314 | +
301 | 315 |              assert attention_mask.shape == tokens.shape, (
302 | 316 |                  f"Attention mask shape {attention_mask.shape} does not match tokens shape "
303 | 317 |                  f"{tokens.shape}"
304 | 318 |              )
305 | 319 |              attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg))
306 |     | -        elif (
307 |     | -            self.tokenizer and self.tokenizer.padding_side == "left"
308 |     | -        ) or past_kv_cache is not None:
309 |     | -            # If the padding side is left or we are using caching, we need to compute the attention
310 |     | -            # mask for the adjustment of absolute positional embeddings and attention masking so
311 |     | -            # that pad tokens are not attended.
312 |     | -
313 |     | -            if prepend_bos is USE_DEFAULT_VALUE:
314 |     | -                prepend_bos = self.cfg.default_prepend_bos
315 |     | -            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
316 |     | -
317 | 320 |          if past_kv_cache is not None:
318 | 321 |              # past_kv_cache is not None, so we're doing caching.
319 | 322 |              # We need to extend the previous attention_mask.
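To make the new control flow easier to follow outside the diff, here is a minimal, self-contained sketch of the branch above. It is not the library code: `resolve_attention_mask`, its parameters, and the pad-token comparison standing in for `utils.get_attention_mask` are illustrative assumptions (the real helper also handles prepend_bos), but the branching mirrors the change: an explicit mask is built whenever left padding, a user-supplied mask, or a KV cache makes one necessary.

```python
from typing import Optional

import torch


def resolve_attention_mask(
    tokens: torch.Tensor,                    # [batch, pos] token ids
    attention_mask: Optional[torch.Tensor],  # user-supplied mask, or None
    pad_token_id: Optional[int],
    padding_side: str = "right",
    using_kv_cache: bool = False,
) -> Optional[torch.Tensor]:
    """Sketch of the refactored branch: decide whether an explicit mask is needed,
    and build one from the token ids if the caller did not supply it."""
    needs_explicit_mask = (
        padding_side == "left" or attention_mask is not None or using_kv_cache
    )
    if not needs_explicit_mask:
        # Fast path kept from the old code: no mask at all.
        return None
    if attention_mask is None:
        # Stand-in for utils.get_attention_mask: non-pad positions get 1, pad positions 0.
        if pad_token_id is None:
            attention_mask = torch.ones_like(tokens)
        else:
            attention_mask = (tokens != pad_token_id).long()
    assert attention_mask.shape == tokens.shape
    return attention_mask
```

The key behavioral difference from the removed `elif` is that a caller-supplied mask and a freshly computed mask now flow through the same shape assertion and device move.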
@@ -1080,7 +1083,7 @@ def from_pretrained(
1080 | 1083 |          tokenizer: Optional[PreTrainedTokenizerBase] = None,
1081 | 1084 |          move_to_device: bool = True,
1082 | 1085 |          fold_value_biases: bool = True,
1083 |      | -        default_prepend_bos: bool = True,
     | 1086 | +        default_prepend_bos: Optional[bool] = None,
1084 | 1087 |          default_padding_side: Literal["left", "right"] = "right",
1085 | 1088 |          dtype="float32",
1086 | 1089 |          first_n_layers: Optional[int] = None,
@@ -1202,11 +1205,15 @@ def from_pretrained(
1202 | 1205 |                  remains exactly the same, and so is just broadcast across the destination positions.
1203 | 1206 |              default_prepend_bos: Default behavior of whether to prepend the BOS
1204 | 1207 |                  token when the methods of HookedTransformer process input text to tokenize (only
1205 |      | -                when input is a string). Defaults to True - even for models not explicitly trained
1206 |      | -                with this, heads often use the first position as a resting position and accordingly
1207 |      | -                lose information from the first token, so this empirically seems to give better
1208 |      | -                results. To change the default behavior to False, pass in default_prepend_bos=False.
1209 |      | -                Note that you can also locally override the default behavior by passing in
     | 1208 | +                when input is a string).
     | 1209 | +                Resolution order for default_prepend_bos:
     | 1210 | +                1. If user passes value explicitly, use that value
     | 1211 | +                2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
     | 1212 | +                3. Global default (True)
     | 1213 | +
     | 1214 | +                Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position
     | 1215 | +                and accordingly lose information from the first token, so this empirically seems to give better
     | 1216 | +                results. Note that you can also locally override the default behavior by passing in
1210 | 1217 |                  prepend_bos=True/False when you call a method that processes the input string.
1211 | 1218 |              from_pretrained_kwargs: Any other optional argument passed to
1212 | 1219 |                  HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to
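The three-step resolution order described in the new docstring text can be expressed as a short helper. This is a sketch under stated assumptions, not the library's implementation: the function name `resolve_default_prepend_bos` and the `GLOBAL_DEFAULT_PREPEND_BOS` constant are hypothetical, and the `cfg_dict` key is assumed to be `"default_prepend_bos"` as the docstring suggests.

```python
from typing import Any, Dict, Optional

GLOBAL_DEFAULT_PREPEND_BOS = True  # step 3: the global fallback


def resolve_default_prepend_bos(user_value: Optional[bool], cfg_dict: Dict[str, Any]) -> bool:
    """Hypothetical resolution of default_prepend_bos following the documented order."""
    # 1. An explicitly passed value always wins.
    if user_value is not None:
        return user_value
    # 2. Otherwise use a model-specific default if the config provides one (e.g. False for bloom).
    model_default = cfg_dict.get("default_prepend_bos")
    if model_default is not None:
        return model_default
    # 3. Otherwise fall back to the global default.
    return GLOBAL_DEFAULT_PREPEND_BOS
```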
@@ -1350,7 +1357,7 @@ def from_pretrained_no_processing(
1350 | 1357 |          refactor_factored_attn_matrices=False,
1351 | 1358 |          fold_value_biases=False,
1352 | 1359 |          dtype=torch.float32,
1353 |      | -        default_prepend_bos=True,
     | 1360 | +        default_prepend_bos=None,
1354 | 1361 |          default_padding_side="right",
1355 | 1362 |          **from_pretrained_kwargs,
1356 | 1363 |      ):
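Taken together with the signature changes above, callers can either rely on the resolved default or override it per model. A hedged usage sketch, assuming the listed model names are supported loaders and that bloom's model-specific default is False as the docstring example states:

```python
from transformer_lens import HookedTransformer

# Global default applies: gpt2 resolves default_prepend_bos to True.
gpt2 = HookedTransformer.from_pretrained("gpt2")

# Model-specific default applies: bloom models resolve to False (per the docstring example).
bloom = HookedTransformer.from_pretrained("bigscience/bloom-560m")

# An explicit argument always wins over both defaults.
bloom_bos = HookedTransformer.from_pretrained("bigscience/bloom-560m", default_prepend_bos=True)
```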