Skip to content

Commit f4287b6

Browse files
authored
moved enable hook functionality to separate functions and tested new functions (#613)
1 parent 570a911 commit f4287b6

File tree

2 files changed

+134
-15
lines changed

2 files changed

+134
-15
lines changed

tests/unit/test_hooked_root_module.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from unittest.mock import Mock
2+
3+
from transformer_lens.hook_points import HookedRootModule
4+
5+
MODEL_NAME = "solu-2l"
6+
7+
8+
def test_enable_hook_with_name():
9+
model = HookedRootModule()
10+
model.mod_dict = {"linear": Mock()}
11+
model.context_level = 5
12+
13+
hook = lambda x: False
14+
dir = "fwd"
15+
16+
model._enable_hook_with_name("linear", hook=hook, dir=dir)
17+
18+
model.mod_dict["linear"].add_hook.assert_called_with(hook, dir="fwd", level=5)
19+
20+
21+
def test_enable_hooks_for_points():
22+
model = HookedRootModule()
23+
model.mod_dict = {}
24+
model.context_level = 5
25+
26+
hook_points = {
27+
"linear": Mock(),
28+
"attn": Mock(),
29+
}
30+
31+
enabled = lambda x: x == "attn"
32+
33+
hook = lambda x: False
34+
dir = "bwd"
35+
36+
print(hook_points.items())
37+
model._enable_hooks_for_points(
38+
hook_points=hook_points.items(), enabled=enabled, hook=hook, dir=dir
39+
)
40+
41+
hook_points["attn"].add_hook.assert_called_with(hook, dir="bwd", level=5)
42+
hook_points["linear"].add_hook.assert_not_called()
43+
44+
45+
def test_enable_hook_with_string_param():
46+
model = HookedRootModule()
47+
model.mod_dict = {"linear": Mock()}
48+
model.context_level = 5
49+
50+
hook = lambda x: False
51+
dir = "fwd"
52+
53+
model._enable_hook("linear", hook=hook, dir=dir)
54+
55+
model.mod_dict["linear"].add_hook.assert_called_with(hook, dir="fwd", level=5)
56+
57+
58+
def test_enable_hook_with_callable_param():
59+
model = HookedRootModule()
60+
model.mod_dict = {"linear": Mock()}
61+
model.hook_dict = {
62+
"linear": Mock(),
63+
"attn": Mock(),
64+
}
65+
model.context_level = 5
66+
67+
enabled = lambda x: x == "attn"
68+
69+
hook = lambda x: False
70+
dir = "fwd"
71+
72+
model._enable_hook(enabled, hook=hook, dir=dir)
73+
74+
model.mod_dict["linear"].add_hook.assert_not_called()
75+
model.hook_dict["attn"].add_hook.assert_called_with(hook, dir="fwd", level=5)
76+
model.hook_dict["linear"].add_hook.assert_not_called()

transformer_lens/hook_points.py

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,18 @@
77
from contextlib import contextmanager
88
from dataclasses import dataclass
99
from functools import partial
10-
from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
10+
from typing import (
11+
Callable,
12+
Dict,
13+
Iterable,
14+
List,
15+
Literal,
16+
Optional,
17+
Sequence,
18+
Tuple,
19+
Union,
20+
cast,
21+
)
1122

1223
import torch.nn as nn
1324
import torch.utils.hooks as hooks
@@ -267,6 +278,50 @@ def add_hook(
267278
def add_perma_hook(self, name, hook, dir="fwd") -> None:
268279
self.add_hook(name, hook, dir=dir, is_permanent=True)
269280

281+
def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
282+
"""This function takes a key for the mod_dict and enables the related hook for that module
283+
284+
Args:
285+
name (str): The module name
286+
hook (Callable): The hook to add
287+
dir (Literal["fwd", "bwd"]): The direction for the hook
288+
"""
289+
self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level)
290+
291+
def _enable_hooks_for_points(
292+
self,
293+
hook_points: Iterable[Tuple[str, HookPoint]],
294+
enabled: Callable,
295+
hook: Callable,
296+
dir: Literal["fwd", "bwd"],
297+
):
298+
"""Enables hooks for a list of points
299+
300+
Args:
301+
hook_points (Dict[str, HookPoint]): The hook points
302+
enabled (Callable): _description_
303+
hook (Callable): _description_
304+
dir (Literal["fwd", "bwd"]): _description_
305+
"""
306+
for hook_name, hook_point in hook_points:
307+
if enabled(hook_name):
308+
hook_point.add_hook(hook, dir=dir, level=self.context_level)
309+
310+
def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
311+
"""Enables an individual hook on a hook point
312+
313+
Args:
314+
name (str): The name of the hook
315+
hook (Callable): The actual hook
316+
dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd".
317+
"""
318+
if isinstance(name, str):
319+
self._enable_hook_with_name(name=name, hook=hook, dir=dir)
320+
else:
321+
self._enable_hooks_for_points(
322+
hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
323+
)
324+
270325
@contextmanager
271326
def hooks(
272327
self,
@@ -296,21 +351,9 @@ def hooks(
296351
self.context_level += 1
297352

298353
for name, hook in fwd_hooks:
299-
if isinstance(name, str):
300-
self.mod_dict[name].add_hook(hook, dir="fwd", level=self.context_level)
301-
else:
302-
# Otherwise, name is a Boolean function on names
303-
for hook_name, hp in self.hook_dict.items():
304-
if name(hook_name):
305-
hp.add_hook(hook, dir="fwd", level=self.context_level)
354+
self._enable_hook(name=name, hook=hook, dir="fwd")
306355
for name, hook in bwd_hooks:
307-
if isinstance(name, str):
308-
self.mod_dict[name].add_hook(hook, dir="bwd", level=self.context_level)
309-
else:
310-
# Otherwise, name is a Boolean function on names
311-
for hook_name, hp in self.hook_dict: # type: ignore
312-
if name(hook_name):
313-
hp.add_hook(hook, dir="bwd", level=self.context_level)
356+
self._enable_hook(name=name, hook=hook, dir="bwd")
314357
yield self
315358
finally:
316359
if reset_hooks_end:

0 commit comments

Comments
 (0)