|
7 | 7 | from contextlib import contextmanager
|
8 | 8 | from dataclasses import dataclass
|
9 | 9 | from functools import partial
|
10 |
| -from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast |
| 10 | +from typing import ( |
| 11 | + Callable, |
| 12 | + Dict, |
| 13 | + Iterable, |
| 14 | + List, |
| 15 | + Literal, |
| 16 | + Optional, |
| 17 | + Sequence, |
| 18 | + Tuple, |
| 19 | + Union, |
| 20 | + cast, |
| 21 | +) |
11 | 22 |
|
12 | 23 | import torch.nn as nn
|
13 | 24 | import torch.utils.hooks as hooks
|
@@ -267,6 +278,50 @@ def add_hook(
|
267 | 278 | def add_perma_hook(self, name, hook, dir="fwd") -> None:
|
268 | 279 | self.add_hook(name, hook, dir=dir, is_permanent=True)
|
269 | 280 |
|
| 281 | + def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]): |
| 282 | + """This function takes a key for the mod_dict and enables the related hook for that module |
| 283 | +
|
| 284 | + Args: |
| 285 | + name (str): The module name |
| 286 | + hook (Callable): The hook to add |
| 287 | + dir (Literal["fwd", "bwd"]): The direction for the hook |
| 288 | + """ |
| 289 | + self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level) |
| 290 | + |
| 291 | + def _enable_hooks_for_points( |
| 292 | + self, |
| 293 | + hook_points: Iterable[Tuple[str, HookPoint]], |
| 294 | + enabled: Callable, |
| 295 | + hook: Callable, |
| 296 | + dir: Literal["fwd", "bwd"], |
| 297 | + ): |
| 298 | + """Enables hooks for a list of points |
| 299 | +
|
| 300 | + Args: |
| 301 | + hook_points (Dict[str, HookPoint]): The hook points |
| 302 | + enabled (Callable): _description_ |
| 303 | + hook (Callable): _description_ |
| 304 | + dir (Literal["fwd", "bwd"]): _description_ |
| 305 | + """ |
| 306 | + for hook_name, hook_point in hook_points: |
| 307 | + if enabled(hook_name): |
| 308 | + hook_point.add_hook(hook, dir=dir, level=self.context_level) |
| 309 | + |
| 310 | + def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]): |
| 311 | + """Enables an individual hook on a hook point |
| 312 | +
|
| 313 | + Args: |
| 314 | + name (str): The name of the hook |
| 315 | + hook (Callable): The actual hook |
| 316 | + dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd". |
| 317 | + """ |
| 318 | + if isinstance(name, str): |
| 319 | + self._enable_hook_with_name(name=name, hook=hook, dir=dir) |
| 320 | + else: |
| 321 | + self._enable_hooks_for_points( |
| 322 | + hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir |
| 323 | + ) |
| 324 | + |
270 | 325 | @contextmanager
|
271 | 326 | def hooks(
|
272 | 327 | self,
|
@@ -296,21 +351,9 @@ def hooks(
|
296 | 351 | self.context_level += 1
|
297 | 352 |
|
298 | 353 | for name, hook in fwd_hooks:
|
299 |
| - if isinstance(name, str): |
300 |
| - self.mod_dict[name].add_hook(hook, dir="fwd", level=self.context_level) |
301 |
| - else: |
302 |
| - # Otherwise, name is a Boolean function on names |
303 |
| - for hook_name, hp in self.hook_dict.items(): |
304 |
| - if name(hook_name): |
305 |
| - hp.add_hook(hook, dir="fwd", level=self.context_level) |
| 354 | + self._enable_hook(name=name, hook=hook, dir="fwd") |
306 | 355 | for name, hook in bwd_hooks:
|
307 |
| - if isinstance(name, str): |
308 |
| - self.mod_dict[name].add_hook(hook, dir="bwd", level=self.context_level) |
309 |
| - else: |
310 |
| - # Otherwise, name is a Boolean function on names |
311 |
| - for hook_name, hp in self.hook_dict: # type: ignore |
312 |
| - if name(hook_name): |
313 |
| - hp.add_hook(hook, dir="bwd", level=self.context_level) |
| 356 | + self._enable_hook(name=name, hook=hook, dir="bwd") |
314 | 357 | yield self
|
315 | 358 | finally:
|
316 | 359 | if reset_hooks_end:
|
|
0 commit comments