Skip to content

Add hooks #3029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ disable=fixme,
unnecessary-lambda-assignment,
broad-exception-raised,
consider-using-generator,
too-many-ancestors


# Enable the message, report, category or checker with the given id(s). You can
Expand Down
2 changes: 2 additions & 0 deletions dash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from ._patch import Patch # noqa: F401,E402
from ._jupyter import jupyter_dash # noqa: F401,E402

from ._hooks import hooks # noqa: F401,E402

ctx = callback_context


Expand Down
12 changes: 9 additions & 3 deletions dash/_callback.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import collections
import hashlib
from functools import wraps
from typing import Callable, Optional, Any, List, Tuple

from typing import Callable, Optional, Any, List, Tuple, Union


import flask

from .dependencies import (
handle_callback_args,
handle_grouped_callback_args,
Output,
ClientsideFunction,
Input,
)
from .development.base_component import ComponentRegistry
Expand Down Expand Up @@ -210,7 +213,10 @@ def validate_long_inputs(deps):
)


def clientside_callback(clientside_function, *args, **kwargs):
ClientsideFuncType = Union[str, ClientsideFunction]


def clientside_callback(clientside_function: ClientsideFuncType, *args, **kwargs):
return register_clientside_callback(
GLOBAL_CALLBACK_LIST,
GLOBAL_CALLBACK_MAP,
Expand Down Expand Up @@ -597,7 +603,7 @@ def register_clientside_callback(
callback_map,
config_prevent_initial_callbacks,
inline_scripts,
clientside_function,
clientside_function: ClientsideFuncType,
*args,
**kwargs,
):
Expand Down
231 changes: 231 additions & 0 deletions dash/_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import typing as _t
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious: why this rather than importing the specific things you need from typing (e.g., from typing import TypeVar)?


from importlib import metadata as _importlib_metadata

import typing_extensions as _tx
import flask as _f

from .exceptions import HookError
from .resources import ResourceType
from ._callback import ClientsideFuncType

if _t.TYPE_CHECKING:
from .dash import Dash
from .development.base_component import Component

ComponentType = _t.TypeVar("ComponentType", bound=Component)
LayoutType = _t.Union[ComponentType, _t.List[ComponentType]]
else:
LayoutType = None
ComponentType = None
Dash = None


HookDataType = _tx.TypeVar("HookDataType")


# pylint: disable=too-few-public-methods
class _Hook(_tx.Generic[HookDataType]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstrings? or do we not document internal classes like this?

def __init__(self, func, priority=0, final=False, data: HookDataType = None):
self.func = func
self.final = final
self.data = data
self.priority = priority

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)


class _Hooks:
def __init__(self) -> None:
self._ns = {
"setup": [],
"layout": [],
"routes": [],
"error": [],
"callback": [],
"index": [],
}
self._js_dist = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is _dist short for?

self._css_dist = []
self._clientside_callbacks: _t.List[
_t.Tuple[ClientsideFuncType, _t.Any, _t.Any]
] = []

# final hooks are a single hook added to the end of regular hooks.
self._finals = {}

def add_hook(
self,
hook: str,
func: _t.Callable,
priority: _t.Optional[int] = None,
final=False,
data=None,
):
if final:
existing = self._finals.get(hook)
if existing:
raise HookError("Final hook already present")
self._finals[hook] = _Hook(func, final, data=data)
return
hks = self._ns.get(hook, [])

p = 0
if not priority and len(hks):
priority_max = max(h.priority for h in hks)
p = priority_max - 1

hks.append(_Hook(func, priority=p, data=data))
self._ns[hook] = sorted(hks, reverse=True, key=lambda h: h.priority)

def get_hooks(self, hook: str) -> _t.List[_Hook]:
final = self._finals.get(hook, None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this just to ensure that final is the last hook in this list? if so, maybe a comment to that effect?

if final:
final = [final]
else:
final = []
return self._ns.get(hook, []) + final

def layout(self, priority: _t.Optional[int] = None, final: bool = False):
"""
Run a function when serving the layout, the return value
will be used as the layout.
"""

def _wrap(func: _t.Callable[[LayoutType], LayoutType]):
self.add_hook("layout", func, priority=priority, final=final)
return func

return _wrap

def setup(self, priority: _t.Optional[int] = None, final: bool = False):
"""
Can be used to get a reference to the app after it is instantiated.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the hook's purpose is to get the app, why is the hook called setup?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's to setup the app, called during app setup. Can be used to automatically setup flask extensions, get a reference to the app, add flask blueprint/static routes, etc.

"""

def _setup(func: _t.Callable[[Dash], None]):
self.add_hook("setup", func, priority=priority, final=final)
return func

return _setup

def route(
self,
name: _t.Optional[str] = None,
methods: _t.Sequence[str] = ("GET",),
priority: _t.Optional[int] = None,
final=False,
):
"""
Add a route to the Dash server.
"""

def wrap(func: _t.Callable[[], _f.Response]):
_name = name or func.__name__
self.add_hook(
"routes",
func,
priority=priority,
final=final,
data=dict(name=_name, methods=methods),
)
return func

return wrap

def error(self, priority: _t.Optional[int] = None, final=False):
"""Automatically add an error handler to the dash app."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

line break after/before """ (see line 139-141 below)


def _error(func: _t.Callable[[Exception], _t.Any]):
self.add_hook("error", func, priority=priority, final=final)
return func

return _error

def callback(self, *args, priority: _t.Optional[int] = None, final=False, **kwargs):
"""
Add a callback to all the apps with the hook installed.
"""

def wrap(func):
self.add_hook(
"callback",
func,
priority=priority,
final=final,
data=(list(args), dict(kwargs)),
)
return func

return wrap

def clientside_callback(
self, clientside_function: ClientsideFuncType, *args, **kwargs
):
"""
Add a callback to all the apps with the hook installed.
"""
self._clientside_callbacks.append((clientside_function, args, kwargs))

def script(self, distribution: _t.List[ResourceType]):
"""Add js scripts to the page."""
self._js_dist.extend(distribution)

def stylesheet(self, distribution: _t.List[ResourceType]):
"""Add stylesheets to the page."""
self._css_dist.extend(distribution)

def index(self, priority: _t.Optional[int] = None, final=False):
"""Modify the index of the apps."""

def wrap(func):
self.add_hook(
"index",
func,
priority=priority,
final=final,
)
return func

return wrap


hooks = _Hooks()


class HooksManager:
# Flag to only run `register_setuptools` once
_registered = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a comment explaining what _registered is for - in particular, why is it a class-level variable instead of an instance variable?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

hooks = hooks

# pylint: disable=too-few-public-methods
class HookErrorHandler:
def __init__(self, original):
self.original = original

def __call__(self, err: Exception):
result = None
if self.original:
result = self.original(err)
hook_result = None
for hook in HooksManager.get_hooks("error"):
hook_result = hook(err)
return result or hook_result

@classmethod
def get_hooks(cls, hook: str):
return cls.hooks.get_hooks(hook)

@classmethod
def register_setuptools(cls):
if cls._registered:
# Only have to register once.
return

for dist in _importlib_metadata.distributions():
for entry in dist.entry_points:
# Look for setup.py entry points named `dash-hooks`
if entry.group != "dash-hooks":
continue
entry.load()
52 changes: 51 additions & 1 deletion dash/dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,8 @@ def __init__( # pylint: disable=too-many-statements
for plugin in plugins:
plugin.plug(self)

self._setup_hooks()

# tracks internally if a function already handled at least one request.
self._got_first_request = {"pages": False, "setup_server": False}

Expand All @@ -582,6 +584,38 @@ def __init__( # pylint: disable=too-many-statements
)
self.setup_startup_routes()

def _setup_hooks(self):
# pylint: disable=import-outside-toplevel,protected-access
from ._hooks import HooksManager

self._hooks = HooksManager
self._hooks.register_setuptools()

for setup in self._hooks.get_hooks("setup"):
setup(self)

for hook in self._hooks.get_hooks("callback"):
callback_args, callback_kwargs = hook.data
self.callback(*callback_args, **callback_kwargs)(hook.func)

for (
clientside_function,
args,
kwargs,
) in self._hooks.hooks._clientside_callbacks:
_callback.register_clientside_callback(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice and clean

self._callback_list,
self.callback_map,
self.config.prevent_initial_callbacks,
self._inline_scripts,
clientside_function,
*args,
**kwargs,
)

if self._hooks.get_hooks("error"):
self._on_error = self._hooks.HookErrorHandler(self._on_error)

def init_app(self, app=None, **kwargs):
"""Initialize the parts of Dash that require a flask app."""

Expand Down Expand Up @@ -682,6 +716,9 @@ def _setup_routes(self):
"_alive_" + jupyter_dash.alive_token, jupyter_dash.serve_alive
)

for hook in self._hooks.get_hooks("routes"):
self._add_url(hook.data["name"], hook.func, hook.data["methods"])

# catch-all for front-end routes, used by dcc.Location
self._add_url("<path:path>", self.index)

Expand Down Expand Up @@ -748,6 +785,9 @@ def index_string(self, value):
def serve_layout(self):
layout = self._layout_value()

for hook in self._hooks.get_hooks("layout"):
layout = hook(layout)

# TODO - Set browser cache limit - pass hash into frontend
return flask.Response(
to_json(layout),
Expand Down Expand Up @@ -890,9 +930,13 @@ def _relative_url_path(relative_package_path="", namespace=""):

return srcs

# pylint: disable=protected-access
def _generate_css_dist_html(self):
external_links = self.config.external_stylesheets
links = self._collect_and_register_resources(self.css.get_all_css())
links = self._collect_and_register_resources(
self.css.get_all_css()
+ self.css._resources._filter_resources(self._hooks.hooks._css_dist)
)

return "\n".join(
[
Expand Down Expand Up @@ -941,6 +985,9 @@ def _generate_scripts_html(self):
+ self.scripts._resources._filter_resources(
dash_table._js_dist, dev_bundles=dev
)
+ self.scripts._resources._filter_resources(
self._hooks.hooks._js_dist, dev_bundles=dev
)
)
)

Expand Down Expand Up @@ -1064,6 +1111,9 @@ def index(self, *args, **kwargs): # pylint: disable=unused-argument
renderer=renderer,
)

for hook in self._hooks.get_hooks("index"):
index = hook(index)

checks = (
_re_index_entry_id,
_re_index_config_id,
Expand Down
4 changes: 4 additions & 0 deletions dash/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,7 @@ class PageError(DashException):

class ImportedInsideCallbackError(DashException):
pass


class HookError(DashException):
pass
Loading