
Commit be334fb

Merge pull request #721 from TransformerLensOrg/dev
Release 2.5
2 parents dd8c1e0 + f1ee5fb commit be334fb

4 files changed: +10 -4 lines changed

.github/ISSUE_TEMPLATE/bug.md

Lines changed: 2 additions & 2 deletions
@@ -17,11 +17,11 @@ Please try to provide a minimal example to reproduce the bug. Error messages and
 Describe the characteristic of your environment:
 * Describe how `transformer_lens` was installed (pip, docker, source, ...)
 * What OS are you using? (Linux, MacOS, Windows)
-* Python version (We suppourt 3.7 -3.10 currently)
+* Python version (We support 3.7--3.10 currently)

 **Additional context**
 Add any other context about the problem here.

 ### Checklist

-- [ ] I have checked that there is no similar [issue](https://github.com/TransformerLensOrg/TransformerLens/issues) in the repo (**required**)
+- [ ] I have checked that there is no similar [issue](https://github.com/TransformerLensOrg/TransformerLens/issues) in the repo (**required**)

transformer_lens/HookedTransformer.py

Lines changed: 3 additions & 0 deletions
@@ -1070,6 +1070,7 @@ def from_pretrained(
         default_prepend_bos: bool = True,
         default_padding_side: Literal["left", "right"] = "right",
         dtype="float32",
+        first_n_layers: Optional[int] = None,
         **from_pretrained_kwargs,
     ) -> "HookedTransformer":
         """Load in a Pretrained Model.
@@ -1204,6 +1205,7 @@ def from_pretrained(
                 the model.
             default_padding_side: Which side to pad on when tokenizing. Defaults to
                 "right".
+            first_n_layers: If specified, only load the first n layers of the model.
        """

        assert not (
@@ -1261,6 +1263,7 @@ def from_pretrained(
            n_devices=n_devices,
            default_prepend_bos=default_prepend_bos,
            dtype=dtype,
+           first_n_layers=first_n_layers,
            **from_pretrained_kwargs,
        )
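Taken together with the loader change further down, this adds an opt-in way to truncate a pretrained model at load time. A minimal usage sketch, assuming a TransformerLens install that includes this commit; the model name and layer count below are illustrative, not taken from the diff:

```python
from transformer_lens import HookedTransformer

# Illustrative: load only the first 4 blocks of GPT-2 small (12 blocks in full).
# Per this diff, first_n_layers is forwarded to get_pretrained_model_config,
# which overrides n_layers in the config before the model is built.
model = HookedTransformer.from_pretrained("gpt2", first_n_layers=4)

print(model.cfg.n_layers)  # expected: 4
```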

transformer_lens/HookedTransformerConfig.py

Lines changed: 2 additions & 2 deletions
@@ -78,7 +78,7 @@ class HookedTransformerConfig:
            attention
        attn_types (List[str], *optional*): the types of attention to use for
            local attention
-       weight_init_mode (str): the initialization mode to use for the
+       init_mode (str): the initialization mode to use for the
            weights. Only relevant for custom models, ignored for pre-trained.
            We now support 'gpt2', 'xavier_uniform', 'xavier_normal', 'kaiming_uniform',
            'kaiming_normal'. MuP support to come. Defaults to 'gpt2'.
@@ -100,7 +100,7 @@ class HookedTransformerConfig:
            Used to set sources of randomness (Python, PyTorch and NumPy) and to initialize weights.
            Defaults to None. We recommend setting a seed, so your experiments are reproducible.
        initializer_range (float): The standard deviation of the normal used to
-           initialise the weights, initialized to 0.8 / sqrt(d_model). If weight_init_mode is
+           initialise the weights, initialized to 0.8 / sqrt(d_model). If init_mode is
            'xavier_uniform' or 'xavier_normal', this value is instead treated as the `gain` parameter for the weight
            initialisation (a constant factor to scale the weights by). Defaults to -1.0, which means not set.
        init_weights (bool): Whether to initialize the weights. Defaults to
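The docstring now refers to the field by its actual name, init_mode. A hedged sketch of a from-scratch config that sets it; every dimension and value below is illustrative and not taken from this diff:

```python
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Illustrative custom config: init_mode only matters for models initialized
# from scratch; it is ignored when loading pretrained weights.
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=128,
    n_ctx=256,
    d_head=32,
    n_heads=4,
    d_vocab=1000,
    act_fn="gelu",
    init_mode="xavier_uniform",  # one of the documented options above
    seed=0,
)
model = HookedTransformer(cfg)
```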

transformer_lens/loading_from_pretrained.py

Lines changed: 3 additions & 0 deletions
@@ -1389,6 +1389,7 @@ def get_pretrained_model_config(
     n_devices: int = 1,
     default_prepend_bos: bool = True,
     dtype: torch.dtype = torch.float32,
+    first_n_layers: Optional[int] = None,
     **kwargs,
 ):
     """Returns the pretrained model config as an HookedTransformerConfig object.
@@ -1501,6 +1502,8 @@ def get_pretrained_model_config(
     cfg_dict["default_prepend_bos"] = default_prepend_bos
     if hf_cfg is not None:
         cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
+    if first_n_layers is not None:
+        cfg_dict["n_layers"] = first_n_layers

     cfg = HookedTransformerConfig.from_dict(cfg_dict)
     return cfg
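This is where the truncation actually takes effect: the override is applied to the config dict before HookedTransformerConfig.from_dict is called. A minimal sketch of the config-level behaviour (model name and layer count are illustrative):

```python
from transformer_lens import loading_from_pretrained as loading

# Illustrative: the returned config reflects the override, so a model built
# from it only contains (and loads weights for) the first 4 blocks.
cfg = loading.get_pretrained_model_config("gpt2", first_n_layers=4)
print(cfg.n_layers)  # expected: 4
```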
