diff --git a/.pylintrc b/.pylintrc index 6f86387a2b..a1314bf820 100644 --- a/.pylintrc +++ b/.pylintrc @@ -75,6 +75,7 @@ disable=fixme, unnecessary-lambda-assignment, broad-exception-raised, consider-using-generator, + too-many-ancestors # Enable the message, report, category or checker with the given id(s). You can diff --git a/dash/__init__.py b/dash/__init__.py index 7fcac0f0ed..39ed65c539 100644 --- a/dash/__init__.py +++ b/dash/__init__.py @@ -41,6 +41,8 @@ from ._patch import Patch # noqa: F401,E402 from ._jupyter import jupyter_dash # noqa: F401,E402 +from ._hooks import hooks # noqa: F401,E402 + ctx = callback_context diff --git a/dash/_callback.py b/dash/_callback.py index 0e901ea8cc..8524f97e5b 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,7 +1,9 @@ import collections import hashlib from functools import wraps -from typing import Callable, Optional, Any, List, Tuple + +from typing import Callable, Optional, Any, List, Tuple, Union + import flask @@ -9,6 +11,7 @@ handle_callback_args, handle_grouped_callback_args, Output, + ClientsideFunction, Input, ) from .development.base_component import ComponentRegistry @@ -210,7 +213,10 @@ def validate_long_inputs(deps): ) -def clientside_callback(clientside_function, *args, **kwargs): +ClientsideFuncType = Union[str, ClientsideFunction] + + +def clientside_callback(clientside_function: ClientsideFuncType, *args, **kwargs): return register_clientside_callback( GLOBAL_CALLBACK_LIST, GLOBAL_CALLBACK_MAP, @@ -597,7 +603,7 @@ def register_clientside_callback( callback_map, config_prevent_initial_callbacks, inline_scripts, - clientside_function, + clientside_function: ClientsideFuncType, *args, **kwargs, ): diff --git a/dash/_hooks.py b/dash/_hooks.py new file mode 100644 index 0000000000..be1c00eee7 --- /dev/null +++ b/dash/_hooks.py @@ -0,0 +1,231 @@ +import typing as _t + +from importlib import metadata as _importlib_metadata + +import typing_extensions as _tx +import flask as _f + +from .exceptions import HookError +from .resources import ResourceType +from ._callback import ClientsideFuncType + +if _t.TYPE_CHECKING: + from .dash import Dash + from .development.base_component import Component + + ComponentType = _t.TypeVar("ComponentType", bound=Component) + LayoutType = _t.Union[ComponentType, _t.List[ComponentType]] +else: + LayoutType = None + ComponentType = None + Dash = None + + +HookDataType = _tx.TypeVar("HookDataType") + + +# pylint: disable=too-few-public-methods +class _Hook(_tx.Generic[HookDataType]): + def __init__(self, func, priority=0, final=False, data: HookDataType = None): + self.func = func + self.final = final + self.data = data + self.priority = priority + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + +class _Hooks: + def __init__(self) -> None: + self._ns = { + "setup": [], + "layout": [], + "routes": [], + "error": [], + "callback": [], + "index": [], + } + self._js_dist = [] + self._css_dist = [] + self._clientside_callbacks: _t.List[ + _t.Tuple[ClientsideFuncType, _t.Any, _t.Any] + ] = [] + + # final hooks are a single hook added to the end of regular hooks. + self._finals = {} + + def add_hook( + self, + hook: str, + func: _t.Callable, + priority: _t.Optional[int] = None, + final=False, + data=None, + ): + if final: + existing = self._finals.get(hook) + if existing: + raise HookError("Final hook already present") + self._finals[hook] = _Hook(func, final, data=data) + return + hks = self._ns.get(hook, []) + + p = 0 + if not priority and len(hks): + priority_max = max(h.priority for h in hks) + p = priority_max - 1 + + hks.append(_Hook(func, priority=p, data=data)) + self._ns[hook] = sorted(hks, reverse=True, key=lambda h: h.priority) + + def get_hooks(self, hook: str) -> _t.List[_Hook]: + final = self._finals.get(hook, None) + if final: + final = [final] + else: + final = [] + return self._ns.get(hook, []) + final + + def layout(self, priority: _t.Optional[int] = None, final: bool = False): + """ + Run a function when serving the layout, the return value + will be used as the layout. + """ + + def _wrap(func: _t.Callable[[LayoutType], LayoutType]): + self.add_hook("layout", func, priority=priority, final=final) + return func + + return _wrap + + def setup(self, priority: _t.Optional[int] = None, final: bool = False): + """ + Can be used to get a reference to the app after it is instantiated. + """ + + def _setup(func: _t.Callable[[Dash], None]): + self.add_hook("setup", func, priority=priority, final=final) + return func + + return _setup + + def route( + self, + name: _t.Optional[str] = None, + methods: _t.Sequence[str] = ("GET",), + priority: _t.Optional[int] = None, + final=False, + ): + """ + Add a route to the Dash server. + """ + + def wrap(func: _t.Callable[[], _f.Response]): + _name = name or func.__name__ + self.add_hook( + "routes", + func, + priority=priority, + final=final, + data=dict(name=_name, methods=methods), + ) + return func + + return wrap + + def error(self, priority: _t.Optional[int] = None, final=False): + """Automatically add an error handler to the dash app.""" + + def _error(func: _t.Callable[[Exception], _t.Any]): + self.add_hook("error", func, priority=priority, final=final) + return func + + return _error + + def callback(self, *args, priority: _t.Optional[int] = None, final=False, **kwargs): + """ + Add a callback to all the apps with the hook installed. + """ + + def wrap(func): + self.add_hook( + "callback", + func, + priority=priority, + final=final, + data=(list(args), dict(kwargs)), + ) + return func + + return wrap + + def clientside_callback( + self, clientside_function: ClientsideFuncType, *args, **kwargs + ): + """ + Add a callback to all the apps with the hook installed. + """ + self._clientside_callbacks.append((clientside_function, args, kwargs)) + + def script(self, distribution: _t.List[ResourceType]): + """Add js scripts to the page.""" + self._js_dist.extend(distribution) + + def stylesheet(self, distribution: _t.List[ResourceType]): + """Add stylesheets to the page.""" + self._css_dist.extend(distribution) + + def index(self, priority: _t.Optional[int] = None, final=False): + """Modify the index of the apps.""" + + def wrap(func): + self.add_hook( + "index", + func, + priority=priority, + final=final, + ) + return func + + return wrap + + +hooks = _Hooks() + + +class HooksManager: + # Flag to only run `register_setuptools` once + _registered = False + hooks = hooks + + # pylint: disable=too-few-public-methods + class HookErrorHandler: + def __init__(self, original): + self.original = original + + def __call__(self, err: Exception): + result = None + if self.original: + result = self.original(err) + hook_result = None + for hook in HooksManager.get_hooks("error"): + hook_result = hook(err) + return result or hook_result + + @classmethod + def get_hooks(cls, hook: str): + return cls.hooks.get_hooks(hook) + + @classmethod + def register_setuptools(cls): + if cls._registered: + # Only have to register once. + return + + for dist in _importlib_metadata.distributions(): + for entry in dist.entry_points: + # Look for setup.py entry points named `dash-hooks` + if entry.group != "dash-hooks": + continue + entry.load() diff --git a/dash/dash.py b/dash/dash.py index 3ad375c823..8f46dd9f7c 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -567,6 +567,8 @@ def __init__( # pylint: disable=too-many-statements for plugin in plugins: plugin.plug(self) + self._setup_hooks() + # tracks internally if a function already handled at least one request. self._got_first_request = {"pages": False, "setup_server": False} @@ -582,6 +584,38 @@ def __init__( # pylint: disable=too-many-statements ) self.setup_startup_routes() + def _setup_hooks(self): + # pylint: disable=import-outside-toplevel,protected-access + from ._hooks import HooksManager + + self._hooks = HooksManager + self._hooks.register_setuptools() + + for setup in self._hooks.get_hooks("setup"): + setup(self) + + for hook in self._hooks.get_hooks("callback"): + callback_args, callback_kwargs = hook.data + self.callback(*callback_args, **callback_kwargs)(hook.func) + + for ( + clientside_function, + args, + kwargs, + ) in self._hooks.hooks._clientside_callbacks: + _callback.register_clientside_callback( + self._callback_list, + self.callback_map, + self.config.prevent_initial_callbacks, + self._inline_scripts, + clientside_function, + *args, + **kwargs, + ) + + if self._hooks.get_hooks("error"): + self._on_error = self._hooks.HookErrorHandler(self._on_error) + def init_app(self, app=None, **kwargs): """Initialize the parts of Dash that require a flask app.""" @@ -682,6 +716,9 @@ def _setup_routes(self): "_alive_" + jupyter_dash.alive_token, jupyter_dash.serve_alive ) + for hook in self._hooks.get_hooks("routes"): + self._add_url(hook.data["name"], hook.func, hook.data["methods"]) + # catch-all for front-end routes, used by dcc.Location self._add_url("", self.index) @@ -748,6 +785,9 @@ def index_string(self, value): def serve_layout(self): layout = self._layout_value() + for hook in self._hooks.get_hooks("layout"): + layout = hook(layout) + # TODO - Set browser cache limit - pass hash into frontend return flask.Response( to_json(layout), @@ -890,9 +930,13 @@ def _relative_url_path(relative_package_path="", namespace=""): return srcs + # pylint: disable=protected-access def _generate_css_dist_html(self): external_links = self.config.external_stylesheets - links = self._collect_and_register_resources(self.css.get_all_css()) + links = self._collect_and_register_resources( + self.css.get_all_css() + + self.css._resources._filter_resources(self._hooks.hooks._css_dist) + ) return "\n".join( [ @@ -941,6 +985,9 @@ def _generate_scripts_html(self): + self.scripts._resources._filter_resources( dash_table._js_dist, dev_bundles=dev ) + + self.scripts._resources._filter_resources( + self._hooks.hooks._js_dist, dev_bundles=dev + ) ) ) @@ -1064,6 +1111,9 @@ def index(self, *args, **kwargs): # pylint: disable=unused-argument renderer=renderer, ) + for hook in self._hooks.get_hooks("index"): + index = hook(index) + checks = ( _re_index_entry_id, _re_index_config_id, diff --git a/dash/exceptions.py b/dash/exceptions.py index ce17986e54..62008210e1 100644 --- a/dash/exceptions.py +++ b/dash/exceptions.py @@ -101,3 +101,7 @@ class PageError(DashException): class ImportedInsideCallbackError(DashException): pass + + +class HookError(DashException): + pass diff --git a/dash/resources.py b/dash/resources.py index 2ac080d57a..2813da951f 100644 --- a/dash/resources.py +++ b/dash/resources.py @@ -2,20 +2,53 @@ import warnings import os +import typing as _t +import typing_extensions as _tx + + from .development.base_component import ComponentRegistry from . import exceptions +# ResourceType has `async` key, use the init form to be able to provide it. +ResourceType = _tx.TypedDict( + "ResourceType", + { + "namespace": str, + "async": _t.Union[bool, _t.Literal["eager", "lazy"]], + "dynamic": bool, + "relative_package_path": str, + "external_url": str, + "dev_package_path": str, + "absolute_path": str, + "asset_path": str, + "external_only": bool, + "filepath": str, + }, + total=False, +) + + +# pylint: disable=too-few-public-methods +class ResourceConfig: + def __init__(self, serve_locally, eager_loading): + self.eager_loading = eager_loading + self.serve_locally = serve_locally + + class Resources: - def __init__(self, resource_name): - self._resources = [] + def __init__(self, resource_name: str, config: ResourceConfig): + self._resources: _t.List[ResourceType] = [] self.resource_name = resource_name + self.config = config - def append_resource(self, resource): + def append_resource(self, resource: ResourceType): self._resources.append(resource) # pylint: disable=too-many-branches - def _filter_resources(self, all_resources, dev_bundles=False): + def _filter_resources( + self, all_resources: _t.List[ResourceType], dev_bundles=False + ): filtered_resources = [] for s in all_resources: filtered_resource = {} @@ -45,7 +78,9 @@ def _filter_resources(self, all_resources, dev_bundles=False): ) if "namespace" in s: filtered_resource["namespace"] = s["namespace"] - if "external_url" in s and not self.config.serve_locally: + if "external_url" in s and ( + s.get("external_only") or not self.config.serve_locally + ): filtered_resource["external_url"] = s["external_url"] elif "dev_package_path" in s and dev_bundles: filtered_resource["relative_package_path"] = s["dev_package_path"] @@ -54,14 +89,14 @@ def _filter_resources(self, all_resources, dev_bundles=False): elif "absolute_path" in s: filtered_resource["absolute_path"] = s["absolute_path"] elif "asset_path" in s: - info = os.stat(s["filepath"]) + info = os.stat(s["filepath"]) # type: ignore filtered_resource["asset_path"] = s["asset_path"] filtered_resource["ts"] = info.st_mtime elif self.config.serve_locally: warnings.warn( ( "You have set your config to `serve_locally=True` but " - f"A local version of {s['external_url']} is not available.\n" + f"A local version of {s['external_url']} is not available.\n" # type: ignore "If you added this file with " "`app.scripts.append_script` " "or `app.css.append_css`, use `external_scripts` " @@ -95,32 +130,25 @@ def get_library_resources(self, libraries, dev_bundles=False): return self._filter_resources(all_resources, dev_bundles) -# pylint: disable=too-few-public-methods -class _Config: - def __init__(self, serve_locally, eager_loading): - self.eager_loading = eager_loading - self.serve_locally = serve_locally - - class Css: - def __init__(self, serve_locally): - self._resources = Resources("_css_dist") - self._resources.config = self.config = _Config(serve_locally, True) + def __init__(self, serve_locally: bool): + self.config = ResourceConfig(serve_locally, True) + self._resources = Resources("_css_dist", self.config) - def append_css(self, stylesheet): + def append_css(self, stylesheet: ResourceType): self._resources.append_resource(stylesheet) def get_all_css(self): return self._resources.get_all_resources() - def get_library_css(self, libraries): + def get_library_css(self, libraries: _t.List[str]): return self._resources.get_library_resources(libraries) class Scripts: def __init__(self, serve_locally, eager): - self._resources = Resources("_js_dist") - self._resources.config = self.config = _Config(serve_locally, eager) + self.config = ResourceConfig(serve_locally, eager) + self._resources = Resources("_js_dist", self.config) def append_script(self, script): self._resources.append_resource(script) diff --git a/tests/integration/callbacks/test_layout_paths_with_callbacks.py b/tests/integration/callbacks/test_layout_paths_with_callbacks.py index 6d0152328f..8a3405c831 100644 --- a/tests/integration/callbacks/test_layout_paths_with_callbacks.py +++ b/tests/integration/callbacks/test_layout_paths_with_callbacks.py @@ -4,7 +4,10 @@ from dash import Dash, Input, Output, dcc, html import dash.testing.wait as wait +from flaky import flaky + +@flaky(max_runs=3) def test_cblp001_radio_buttons_callbacks_generating_children(dash_duo): TIMEOUT = 2 with open(os.path.join(os.path.dirname(__file__), "state_path.json")) as fp: diff --git a/tests/integration/test_hooks.py b/tests/integration/test_hooks.py new file mode 100644 index 0000000000..999819a816 --- /dev/null +++ b/tests/integration/test_hooks.py @@ -0,0 +1,188 @@ +from flask import jsonify +import requests +import pytest + +from dash import Dash, Input, Output, html, hooks, set_props + + +@pytest.fixture +def hook_cleanup(): + yield + hooks._ns["layout"] = [] + hooks._ns["setup"] = [] + hooks._ns["route"] = [] + hooks._ns["error"] = [] + hooks._ns["callback"] = [] + hooks._ns["index"] = [] + hooks._css_dist = [] + hooks._js_dist = [] + hooks._finals = {} + hooks._clientside_callbacks = [] + + +def test_hook001_layout(hook_cleanup, dash_duo): + @hooks.layout() + def on_layout(layout): + return [html.Div("Header", id="header")] + layout + + app = Dash() + app.layout = [html.Div("Body", id="body")] + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#header", "Header") + dash_duo.wait_for_text_to_equal("#body", "Body") + + +def test_hook002_setup(hook_cleanup): + setup_title = None + + @hooks.setup() + def on_setup(app: Dash): + nonlocal setup_title + setup_title = app.title + + app = Dash(title="setup-test") + app.layout = html.Div("setup") + + assert setup_title == "setup-test" + + +def test_hook003_route(hook_cleanup, dash_duo): + @hooks.route(methods=("POST",)) + def hook_route(): + return jsonify({"success": True}) + + app = Dash() + app.layout = html.Div("hook route") + + dash_duo.start_server(app) + response = requests.post(f"{dash_duo.server_url}/hook_route") + data = response.json() + assert data["success"] + + +def test_hook004_error(hook_cleanup, dash_duo): + @hooks.error() + def on_error(error): + set_props("error", {"children": str(error)}) + + app = Dash() + app.layout = [html.Button("start", id="start"), html.Div(id="error")] + + @app.callback(Input("start", "n_clicks"), prevent_initial_call=True) + def on_click(_): + raise Exception("hook error") + + dash_duo.start_server(app) + dash_duo.wait_for_element("#start").click() + dash_duo.wait_for_text_to_equal("#error", "hook error") + + +def test_hook005_callback(hook_cleanup, dash_duo): + @hooks.callback( + Output("output", "children"), + Input("start", "n_clicks"), + prevent_initial_call=True, + ) + def on_hook_cb(n_clicks): + return f"clicked {n_clicks}" + + app = Dash() + app.layout = [ + html.Button("start", id="start"), + html.Div(id="output"), + ] + + dash_duo.start_server(app) + dash_duo.wait_for_element("#start").click() + dash_duo.wait_for_text_to_equal("#output", "clicked 1") + + +def test_hook006_priority_final(hook_cleanup, dash_duo): + @hooks.layout(final=True) + def hook_final(layout): + return html.Div([html.Div("final")] + [layout], id="final-wrapper") + + @hooks.layout() + def hook1(layout): + layout.children.append(html.Div("first")) + return layout + + @hooks.layout() + def hook2(layout): + layout.children.append(html.Div("second")) + return layout + + @hooks.layout() + def hook3(layout): + layout.children.append(html.Div("third")) + return layout + + @hooks.layout(priority=6) + def hook4(layout): + layout.children.insert(0, html.Div("Prime")) + return layout + + app = Dash() + + app.layout = html.Div([html.Div("layout")], id="body") + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#final-wrapper > div:first-child", "final") + dash_duo.wait_for_text_to_equal("#body > div:first-child", "Prime") + dash_duo.wait_for_text_to_equal("#body > div:nth-child(2)", "layout") + dash_duo.wait_for_text_to_equal("#body > div:nth-child(3)", "first") + dash_duo.wait_for_text_to_equal("#body > div:nth-child(4)", "second") + dash_duo.wait_for_text_to_equal("#body > div:nth-child(5)", "third") + + +def test_hook007_hook_index(hook_cleanup, dash_duo): + @hooks.index() + def hook_index(index: str): + body = "" + ib = index.find(body) + len(body) + injected = '
Hooked
' + new_index = index[ib:] + injected + index[: ib + 1] + return new_index + + app = Dash() + app.layout = html.Div(["index"]) + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#hooked", "Hooked") + + +def test_hook008_hook_distributions(hook_cleanup, dash_duo): + js_uri = "https://example.com/none.js" + css_uri = "https://example.com/none.css" + hooks.script([{"external_url": js_uri, "external_only": True}]) + hooks.stylesheet([{"external_url": css_uri, "external_only": True}]) + + app = Dash() + app.layout = html.Div("distribute") + + dash_duo.start_server(app) + + assert dash_duo.find_element(f'script[src="{js_uri}"]') + assert dash_duo.find_element(f'link[href="{css_uri}"]') + + +def test_hook009_hook_clientside_callback(hook_cleanup, dash_duo): + hooks.clientside_callback( + "(n) => `Called ${n}`", + Output("hook-output", "children"), + Input("hook-start", "n_clicks"), + prevent_initial_call=True, + ) + + app = Dash() + app.layout = [ + html.Button("start", id="hook-start"), + html.Div(id="hook-output"), + ] + + dash_duo.start_server(app) + + dash_duo.wait_for_element("#hook-start").click() + dash_duo.wait_for_text_to_equal("#hook-output", "Called 1") diff --git a/tests/unit/library/test_async_resources.py b/tests/unit/library/test_async_resources.py index 4781f208fe..a4d738221b 100644 --- a/tests/unit/library/test_async_resources.py +++ b/tests/unit/library/test_async_resources.py @@ -1,15 +1,8 @@ -from dash.resources import Resources - - -class obj(object): - def __init__(self, dict): - self.__dict__ = dict +from dash.resources import Resources, ResourceConfig def test_resources_eager(): - - resource = Resources("js_test") - resource.config = obj({"eager_loading": True, "serve_locally": False}) + resource = Resources("js_test", ResourceConfig(False, True)) filtered = resource._filter_resources( [ @@ -32,9 +25,7 @@ def test_resources_eager(): def test_resources_lazy(): - - resource = Resources("js_test") - resource.config = obj({"eager_loading": False, "serve_locally": False}) + resource = Resources("js_test", ResourceConfig(False, False)) filtered = resource._filter_resources( [