Skip to content

Commit 47fe156

Browse files
authored
Merge pull request #917 from TransformerLensOrg/dev
v2.15.1
2 parents e65fafb + d2f3f15 commit 47fe156

File tree

4 files changed: 25 additions (+25) and 1 deletion (-1)

.github/workflows/checks.yml

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,12 @@ jobs:
6767
run: |
6868
poetry check --lock
6969
poetry install --with dev
70+
- name: Authenticate HuggingFace CLI
71+
run: |
72+
pip install huggingface_hub
73+
huggingface-cli login --token $HF_TOKEN
74+
env:
75+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
7076
- name: Unit Test
7177
run: make unit-test
7278
env:
@@ -106,6 +112,12 @@ jobs:
106112
run: make docstring-test
107113
- name: Type check
108114
run: poetry run mypy .
115+
- name: Authenticate HuggingFace CLI
116+
run: |
117+
pip install huggingface_hub
118+
huggingface-cli login --token $HF_TOKEN
119+
env:
120+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
109121
- name: Test Suite with Coverage Report
110122
run: make coverage-report-test
111123
env:
@@ -196,6 +208,12 @@ jobs:
196208
with:
197209
name: test-coverage
198210
path: docs/source/_static/coverage
211+
- name: Authenticate HuggingFace CLI
212+
run: |
213+
pip install huggingface_hub
214+
huggingface-cli login --token $HF_TOKEN
215+
env:
216+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
199217
- name: Build Docs
200218
run: poetry run build-docs
201219
env:

transformer_lens/HookedTransformerConfig.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -262,6 +262,7 @@ class HookedTransformerConfig:
262262
NTK_by_parts_low_freq_factor: float = 1.0
263263
NTK_by_parts_high_freq_factor: float = 4.0
264264
NTK_by_parts_factor: float = 8.0
265+
NTK_original_ctx_len: int = 8192
265266

266267
def __post_init__(self):
267268
if self.n_heads == -1:

transformer_lens/components/abstract_attention.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -504,7 +504,7 @@ def calculate_sin_cos_rotary(
504504
factor = self.cfg.NTK_by_parts_factor
505505
low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
506506
high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
507-
old_context_len = n_ctx
507+
old_context_len = self.cfg.NTK_original_ctx_len
508508

509509
low_freq_wavelen = old_context_len / low_freq_factor
510510
high_freq_wavelen = old_context_len / high_freq_factor

transformer_lens/loading_from_pretrained.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -947,6 +947,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
947947
"NTK_by_parts_low_freq_factor": 1.0,
948948
"NTK_by_parts_high_freq_factor": 4.0,
949949
"NTK_by_parts_factor": 32.0,
950+
"NTK_original_ctx_len": 8192,
950951
}
951952
elif "Llama-3.2-3B" in official_model_name:
952953
cfg_dict = {
@@ -971,6 +972,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
971972
"NTK_by_parts_low_freq_factor": 1.0,
972973
"NTK_by_parts_high_freq_factor": 4.0,
973974
"NTK_by_parts_factor": 32.0,
975+
"NTK_original_ctx_len": 8192,
974976
}
975977
elif "Llama-3.3-70B" in official_model_name:
976978
cfg_dict = {
@@ -995,6 +997,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
995997
"NTK_by_parts_low_freq_factor": 1.0,
996998
"NTK_by_parts_high_freq_factor": 4.0,
997999
"NTK_by_parts_factor": 8.0,
1000+
"NTK_original_ctx_len": 8192,
9981001
}
9991002
elif "Llama-3.1-8B" in official_model_name:
10001003
cfg_dict = {
@@ -1019,6 +1022,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
10191022
"NTK_by_parts_low_freq_factor": 1.0,
10201023
"NTK_by_parts_high_freq_factor": 4.0,
10211024
"NTK_by_parts_factor": 8.0,
1025+
"NTK_original_ctx_len": 8192,
10221026
}
10231027
elif "Llama-3.1-70B" in official_model_name:
10241028
cfg_dict = {
@@ -1043,6 +1047,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
10431047
"NTK_by_parts_low_freq_factor": 1.0,
10441048
"NTK_by_parts_high_freq_factor": 4.0,
10451049
"NTK_by_parts_factor": 8.0,
1050+
"NTK_original_ctx_len": 8192,
10461051
}
10471052
elif architecture == "GPTNeoForCausalLM":
10481053
cfg_dict = {

0 commit comments

Comments
 (0)