From d3e07586348e7b3b5425a99895953f13633bbad2 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Wed, 23 Oct 2024 22:55:52 +0200 Subject: [PATCH 1/5] scopeless graph compiler --- src/dishka/graph_compiler.py | 232 +++++++++++++++++++++++++++++++++++ src/dishka/registry.py | 31 ++++- 2 files changed, 260 insertions(+), 3 deletions(-) create mode 100644 src/dishka/graph_compiler.py diff --git a/src/dishka/graph_compiler.py b/src/dishka/graph_compiler.py new file mode 100644 index 00000000..ae1c1413 --- /dev/null +++ b/src/dishka/graph_compiler.py @@ -0,0 +1,232 @@ +import linecache +import re +from textwrap import indent +from typing import Any, Sequence, Mapping + +from .container_objects import Exit +from .entities.factory_type import FactoryType, FactoryData +from .entities.key import DependencyKey +from .entities.scope import BaseScope +from .exceptions import NoContextValueError, UnsupportedFactoryError +from .text_rendering import get_name + + +class Node(FactoryData): + __slots__ = ( + "dependencies", + "kw_dependencies", + "cache", + ) + + def __init__( + self, + *, + dependencies: Sequence["Node"], + kw_dependencies: Mapping[str, "Node"], + source: Any, + provides: DependencyKey, + scope: BaseScope, + type_: FactoryType, + cache: bool, + ) -> None: + super().__init__( + source=source, + provides=provides, + type_=type_, + scope=scope, + ) + self.dependencies = dependencies + self.kw_dependencies = kw_dependencies + self.cache = cache + + +def make_args(args: list[str], kwargs: dict[str, str]) -> str: + res = ", ".join(args) + if not kwargs: + return res + if res: + res += ", " + res += ", ".join( + f"{arg}={var}" + for arg, var in kwargs + ) + return res + + +GENERATOR = """ +generator = {source}({args}) +{var} = next(generator) +exits.append(Exit(factory_type, generator)) +""" +ASYNC_GENERATOR = """ +generator = {source}({args}) +{var} = await anext(generator) +exits.append(Exit(factory_type, generator)) +""" +FACTORY = """ +{var} = {source}({args}) +""" +ASYNC_FACTORY = """ +{var} = await {source}({args}) +""" +VALUE = """ +{var} = {source} +""" +ALIAS = """ +{var} = {args} +""" +CONTEXT = """ +raise NoContextValueError({key}) +""" +INVALID = """ +raise UnsupportedFactoryError( + f"Unsupported factory type {{factory_type}}.", +) +""" + +ASYNC_BODIES = { + FactoryType.ASYNC_FACTORY: ASYNC_FACTORY, + FactoryType.FACTORY: FACTORY, + FactoryType.ASYNC_GENERATOR: ASYNC_GENERATOR, + FactoryType.GENERATOR: GENERATOR, + FactoryType.VALUE: VALUE, + FactoryType.CONTEXT: CONTEXT, + FactoryType.ALIAS: ALIAS, +} +SYNC_BODIES = { + FactoryType.FACTORY: FACTORY, + FactoryType.GENERATOR: GENERATOR, + FactoryType.VALUE: VALUE, + FactoryType.CONTEXT: CONTEXT, + FactoryType.ALIAS: ALIAS, +} +FUNC_TEMPLATE = """ +{async_}def {func_name}(getter, exits, context): + cache_getter = context.get + {body} + return {var} +""" + +IF_TEMPLATE = """ +if {var} := cache_getter({key}): + pass # cache found +else: + {deps} + {body} + {cache} +""" +CACHE = "context[{key}] = {var}" + + +def make_name(obj: Any, ns: dict[Any, str]) -> str: + if isinstance(obj, DependencyKey): + key = get_name(obj.type_hint, include_module=False) + obj.component + else: + key = get_name(obj, include_module=False) + key = re.sub(r"\W", "_", key) + if key in ns: + key += f"_{len(ns)}" + return key + + +def make_globals(node: Node, ns: dict[Any, str]): + if node.provides not in ns: + ns[node.provides] = make_name(node.provides, ns) + if node.source not in ns: + ns[node.source] = make_name(node.source, ns) + for dep in node.dependencies: + make_globals(dep, ns) + for dep in node.kw_dependencies.values(): + make_globals(dep, ns) + + +def make_var(node: Node, ns: dict[Any, str]): + return "value_" + ns[node.provides].lower() + + +def make_if(node: Node, node_var: str, ns: dict[Any, str], + is_async: bool) -> str: + node_key = ns[node.provides] + node_source = ns[node.source] + + deps = "".join( + make_if(dep, make_var(dep, ns), ns, is_async) + for dep in node.dependencies + ) + deps += "".join( + make_if(dep, make_var(dep, ns), ns, is_async) + for dep in node.kw_dependencies.values() + ) + deps = indent(deps, " ") + if node.cache: + cache = CACHE.format(var=node_var, key=node_key) + else: + cache = "# no cache" + + args = [ns.get(dep.provides) for dep in node.dependencies] + kwargs = { + key: ns.get(dep.provides) + for key, dep in node.kw_dependencies.items() + } + + if is_async: + body_template = ASYNC_BODIES.get(node.type, INVALID) + else: + body_template = SYNC_BODIES.get(node.type, INVALID) + + args_str = make_args(args, kwargs) + body_str = body_template.format( + source=node_source, + key=node_key, + var=node_var, + args=args_str, + ) + body_str = indent(body_str, " ") + + return IF_TEMPLATE.format( + var=node_var, + key=node_key, + deps=deps, + body=body_str, + cache=cache, + ) + + +def make_func( + node: Node, ns: dict[Any, str], func_name: str, is_async: bool, +) -> str: + node_var = make_var(node, ns) + body = make_if(node, node_var, ns, is_async) + body = indent(body, " ") + return FUNC_TEMPLATE.format( + async_="async " if is_async else "", + var=node_var, + body=body, + func_name=func_name, + ) + + +def compile_graph(node: Node, is_async: bool): + ns: dict[Any, str] = { + node.type: "factory_type", + Exit: "Exit", + NoContextValueError: "NoContextValueError", + UnsupportedFactoryError: "UnsupportedFactoryError", + } + make_globals(node, ns) + func_name = f"get_{ns[node.provides].lower()}" + src = make_func(node, ns, func_name, is_async=is_async) + src = "\n".join(line for line in src.splitlines() if line.strip()) + + print(src) + print() + source_file_name = f"__dishka_factory_{id(node.provides)}" + if is_async: + source_file_name += "_async" + lines = src.splitlines(keepends=True) + linecache.cache[source_file_name] = ( + len(src), None, lines, source_file_name, + ) + global_ns = {value: key for key, value in ns.items()} + exec(src, global_ns) + return global_ns[func_name] diff --git a/src/dishka/registry.py b/src/dishka/registry.py index c7d67b5c..17151c5e 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -1,6 +1,8 @@ from collections.abc import Callable from typing import Any, TypeVar, get_args, get_origin +from pydantic.v1 import compiled + from ._adaptix.type_tools.fundamentals import get_type_vars from .container_objects import CompiledFactory from .dependency_source import ( @@ -11,6 +13,7 @@ from .entities.key import DependencyKey from .entities.scope import BaseScope from .factory_compiler import compile_factory +from .graph_compiler import Node, compile_graph class Registry: @@ -25,7 +28,7 @@ def __init__(self, scope: BaseScope): def add_factory( self, factory: Factory, - provides: DependencyKey| None = None, + provides: DependencyKey | None = None, ) -> None: if provides is None: provides = factory.provides @@ -40,7 +43,9 @@ def get_compiled( factory = self.get_factory(dependency) if not factory: return None - compiled = compile_factory(factory=factory, is_async=False) + node = make_node(self, factory) + compiled = compile_graph(node=node, is_async=False) + # compiled = compile_factory(factory=factory, is_async=False) self.compiled[dependency] = compiled return compiled @@ -53,7 +58,9 @@ def get_compiled_async( factory = self.get_factory(dependency) if not factory: return None - compiled = compile_factory(factory=factory, is_async=True) + node = make_node(self, factory) + compiled = compile_graph(node=node, is_async=True) + # compiled = compile_factory(factory=factory, is_async=True) self.compiled[dependency] = compiled return compiled @@ -144,3 +151,21 @@ def _specialize_generic( cache=factory.cache, override=factory.override, ) + + +def make_node(registry: Registry, factory: Factory) -> Node: + return Node( + provides=factory.provides, + scope=factory.scope, + source=factory.source, + type_=factory.type, + cache=factory.cache, + dependencies=[ + make_node(registry, registry.get_factory(dep)) + for dep in factory.dependencies + ], + kw_dependencies={ + key: make_node(registry, registry.get_factory(dep)) + for key, dep in factory.kw_dependencies.items() + }, + ) From 87cc8110d9e330b1f4c52db2bfc922b3f563bfb8 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Wed, 23 Oct 2024 23:03:40 +0200 Subject: [PATCH 2/5] go parent --- src/dishka/graph_compiler.py | 16 ++++++++++++---- src/dishka/registry.py | 21 ++++++++++++++++----- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/dishka/graph_compiler.py b/src/dishka/graph_compiler.py index ae1c1413..d224f6c0 100644 --- a/src/dishka/graph_compiler.py +++ b/src/dishka/graph_compiler.py @@ -26,7 +26,7 @@ def __init__( source: Any, provides: DependencyKey, scope: BaseScope, - type_: FactoryType, + type_: FactoryType | None, cache: bool, ) -> None: super().__init__( @@ -83,6 +83,12 @@ def make_args(args: list[str], kwargs: dict[str, str]) -> str: f"Unsupported factory type {{factory_type}}.", ) """ +GO_PARENT = """ +{var} = getter({key}) +""" +GO_PARENT_ASYNC = """ +{var} = await getter({key}) +""" ASYNC_BODIES = { FactoryType.ASYNC_FACTORY: ASYNC_FACTORY, @@ -92,6 +98,7 @@ def make_args(args: list[str], kwargs: dict[str, str]) -> str: FactoryType.VALUE: VALUE, FactoryType.CONTEXT: CONTEXT, FactoryType.ALIAS: ALIAS, + None: GO_PARENT_ASYNC, } SYNC_BODIES = { FactoryType.FACTORY: FACTORY, @@ -99,6 +106,7 @@ def make_args(args: list[str], kwargs: dict[str, str]) -> str: FactoryType.VALUE: VALUE, FactoryType.CONTEXT: CONTEXT, FactoryType.ALIAS: ALIAS, + None: GO_PARENT, } FUNC_TEMPLATE = """ {async_}def {func_name}(getter, exits, context): @@ -108,7 +116,7 @@ def make_args(args: list[str], kwargs: dict[str, str]) -> str: """ IF_TEMPLATE = """ -if {var} := cache_getter({key}): +if {var} := cache_getter({key}, None): pass # cache found else: {deps} @@ -163,9 +171,9 @@ def make_if(node: Node, node_var: str, ns: dict[Any, str], else: cache = "# no cache" - args = [ns.get(dep.provides) for dep in node.dependencies] + args = [make_var(dep, ns) for dep in node.dependencies] kwargs = { - key: ns.get(dep.provides) + key: make_var(dep, ns) for key, dep in node.kw_dependencies.items() } diff --git a/src/dishka/registry.py b/src/dishka/registry.py index 17151c5e..df1a26dd 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -43,7 +43,7 @@ def get_compiled( factory = self.get_factory(dependency) if not factory: return None - node = make_node(self, factory) + node = make_node(self, dependency) compiled = compile_graph(node=node, is_async=False) # compiled = compile_factory(factory=factory, is_async=False) self.compiled[dependency] = compiled @@ -58,7 +58,7 @@ def get_compiled_async( factory = self.get_factory(dependency) if not factory: return None - node = make_node(self, factory) + node = make_node(self, dependency) compiled = compile_graph(node=node, is_async=True) # compiled = compile_factory(factory=factory, is_async=True) self.compiled[dependency] = compiled @@ -153,7 +153,18 @@ def _specialize_generic( ) -def make_node(registry: Registry, factory: Factory) -> Node: +def make_node(registry: Registry, key: DependencyKey) -> Node: + factory = registry.get_factory(key) + if not factory: + return Node( + provides=key, + scope=registry.scope, + type_=None, + dependencies=[], + kw_dependencies={}, + cache=False, + source=None, + ) return Node( provides=factory.provides, scope=factory.scope, @@ -161,11 +172,11 @@ def make_node(registry: Registry, factory: Factory) -> Node: type_=factory.type, cache=factory.cache, dependencies=[ - make_node(registry, registry.get_factory(dep)) + make_node(registry, dep) for dep in factory.dependencies ], kw_dependencies={ - key: make_node(registry, registry.get_factory(dep)) + key: make_node(registry, dep) for key, dep in factory.kw_dependencies.items() }, ) From 48152f1236ecf3a9f9d4bdd91c52c07c49186f79 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Wed, 23 Oct 2024 23:07:59 +0200 Subject: [PATCH 3/5] skip cache checking if no cache --- src/dishka/graph_compiler.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/dishka/graph_compiler.py b/src/dishka/graph_compiler.py index d224f6c0..8a51e5c4 100644 --- a/src/dishka/graph_compiler.py +++ b/src/dishka/graph_compiler.py @@ -152,8 +152,9 @@ def make_var(node: Node, ns: dict[Any, str]): return "value_" + ns[node.provides].lower() -def make_if(node: Node, node_var: str, ns: dict[Any, str], - is_async: bool) -> str: +def make_if( + node: Node, node_var: str, ns: dict[Any, str], is_async: bool, +) -> str: node_key = ns[node.provides] node_source = ns[node.source] @@ -189,15 +190,18 @@ def make_if(node: Node, node_var: str, ns: dict[Any, str], var=node_var, args=args_str, ) - body_str = indent(body_str, " ") - return IF_TEMPLATE.format( - var=node_var, - key=node_key, - deps=deps, - body=body_str, - cache=cache, - ) + if node.cache: + body_str = indent(body_str, " ") + return IF_TEMPLATE.format( + var=node_var, + key=node_key, + deps=deps, + body=body_str, + cache=cache, + ) + else: + return "\n".join([deps, body_str, cache]) def make_func( From f8967cc6b720df7042a1fa4b681c7e5b71dc5c41 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Fri, 25 Oct 2024 17:00:10 +0200 Subject: [PATCH 4/5] cache nodes, limit recursion depth --- src/dishka/container.py | 5 ++++- src/dishka/graph_compiler.py | 42 ++++++++++++++++++++++------------- src/dishka/registry.py | 43 ++++++++++++++++++++---------------- 3 files changed, 55 insertions(+), 35 deletions(-) diff --git a/src/dishka/container.py b/src/dishka/container.py index a7a4acca..dd7b8120 100644 --- a/src/dishka/container.py +++ b/src/dishka/container.py @@ -51,7 +51,7 @@ def __init__( ): self.registry = registry self.child_registries = child_registries - self._context = {DependencyKey(type(self), DEFAULT_COMPONENT): self} + self._context = {CONTAINER_KEY: self} if context: for key, value in context.items(): if not isinstance(key, DependencyKey): @@ -252,3 +252,6 @@ def make_container( close_parent=True, ) return container + + +CONTAINER_KEY = DependencyKey(Container, DEFAULT_COMPONENT) \ No newline at end of file diff --git a/src/dishka/graph_compiler.py b/src/dishka/graph_compiler.py index 8a51e5c4..55a73e20 100644 --- a/src/dishka/graph_compiler.py +++ b/src/dishka/graph_compiler.py @@ -11,6 +11,9 @@ from .text_rendering import get_name +MAX_DEPTH = 5 # max code depth, otherwise we get too big file + + class Node(FactoryData): __slots__ = ( "dependencies", @@ -106,7 +109,6 @@ def make_args(args: list[str], kwargs: dict[str, str]) -> str: FactoryType.VALUE: VALUE, FactoryType.CONTEXT: CONTEXT, FactoryType.ALIAS: ALIAS, - None: GO_PARENT, } FUNC_TEMPLATE = """ {async_}def {func_name}(getter, exits, context): @@ -116,19 +118,19 @@ def make_args(args: list[str], kwargs: dict[str, str]) -> str: """ IF_TEMPLATE = """ -if {var} := cache_getter({key}, None): - pass # cache found -else: +if ({var} := cache_getter({key}, ...)) is ...: {deps} {body} {cache} """ CACHE = "context[{key}] = {var}" - +builtins = {getattr(__builtins__, name): name for name in dir(__builtins__)} def make_name(obj: Any, ns: dict[Any, str]) -> str: + if obj in builtins: + return builtins[obj] if isinstance(obj, DependencyKey): - key = get_name(obj.type_hint, include_module=False) + obj.component + key = get_name(obj.type_hint, include_module=False) +"_"+ obj.component else: key = get_name(obj, include_module=False) key = re.sub(r"\W", "_", key) @@ -153,24 +155,33 @@ def make_var(node: Node, ns: dict[Any, str]): def make_if( - node: Node, node_var: str, ns: dict[Any, str], is_async: bool, + node: Node, node_var: str, ns: dict[Any, str], + is_async: bool, + depth: int, ) -> str: node_key = ns[node.provides] node_source = ns[node.source] + if depth > MAX_DEPTH or node.type is None: + if is_async: + return GO_PARENT.format( + var=node_var, + key=node_key, + ) + else: + return GO_PARENT.format( + var=node_var, + key=node_key, + ) deps = "".join( - make_if(dep, make_var(dep, ns), ns, is_async) + make_if(dep, make_var(dep, ns), ns, is_async, depth+1) for dep in node.dependencies ) deps += "".join( - make_if(dep, make_var(dep, ns), ns, is_async) + make_if(dep, make_var(dep, ns), ns, is_async, depth+1) for dep in node.kw_dependencies.values() ) deps = indent(deps, " ") - if node.cache: - cache = CACHE.format(var=node_var, key=node_key) - else: - cache = "# no cache" args = [make_var(dep, ns) for dep in node.dependencies] kwargs = { @@ -192,6 +203,7 @@ def make_if( ) if node.cache: + cache = CACHE.format(var=node_var, key=node_key) body_str = indent(body_str, " ") return IF_TEMPLATE.format( var=node_var, @@ -201,14 +213,14 @@ def make_if( cache=cache, ) else: - return "\n".join([deps, body_str, cache]) + return "\n".join([deps, body_str]) def make_func( node: Node, ns: dict[Any, str], func_name: str, is_async: bool, ) -> str: node_var = make_var(node, ns) - body = make_if(node, node_var, ns, is_async) + body = make_if(node, node_var, ns, is_async, 0) body = indent(body, " ") return FUNC_TEMPLATE.format( async_="async " if is_async else "", diff --git a/src/dishka/registry.py b/src/dishka/registry.py index df1a26dd..618ff0c1 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -1,7 +1,8 @@ +import time from collections.abc import Callable +from linecache import cache from typing import Any, TypeVar, get_args, get_origin -from pydantic.v1 import compiled from ._adaptix.type_tools.fundamentals import get_type_vars from .container_objects import CompiledFactory @@ -12,7 +13,6 @@ from .entities.factory_type import FactoryType from .entities.key import DependencyKey from .entities.scope import BaseScope -from .factory_compiler import compile_factory from .graph_compiler import Node, compile_graph @@ -153,10 +153,12 @@ def _specialize_generic( ) -def make_node(registry: Registry, key: DependencyKey) -> Node: +def make_node(registry: Registry, key: DependencyKey, cache: dict| None = None) -> Node: + if cache is None: + cache = {} factory = registry.get_factory(key) if not factory: - return Node( + node = Node( provides=key, scope=registry.scope, type_=None, @@ -165,18 +167,21 @@ def make_node(registry: Registry, key: DependencyKey) -> Node: cache=False, source=None, ) - return Node( - provides=factory.provides, - scope=factory.scope, - source=factory.source, - type_=factory.type, - cache=factory.cache, - dependencies=[ - make_node(registry, dep) - for dep in factory.dependencies - ], - kw_dependencies={ - key: make_node(registry, dep) - for key, dep in factory.kw_dependencies.items() - }, - ) + else: + node = Node( + provides=factory.provides, + scope=factory.scope, + source=factory.source, + type_=factory.type, + cache=factory.cache, + dependencies=[ + make_node(registry, dep, cache) + for dep in factory.dependencies + ], + kw_dependencies={ + key: make_node(registry, dep, cache) + for key, dep in factory.kw_dependencies.items() + }, + ) + cache[key] = node + return node From 20f2a008f0a12fca019c5458ee8bece85d9c60b0 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Fri, 25 Oct 2024 17:21:33 +0200 Subject: [PATCH 5/5] max depth in graph transform --- src/dishka/registry.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/dishka/registry.py b/src/dishka/registry.py index 618ff0c1..5827791e 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -152,12 +152,13 @@ def _specialize_generic( override=factory.override, ) +MAX_DEPTH = 4 -def make_node(registry: Registry, key: DependencyKey, cache: dict| None = None) -> Node: +def make_node(registry: Registry, key: DependencyKey, cache: dict| None = None, depth: int=0) -> Node: if cache is None: cache = {} factory = registry.get_factory(key) - if not factory: + if not factory or depth>MAX_DEPTH: node = Node( provides=key, scope=registry.scope, @@ -175,11 +176,11 @@ def make_node(registry: Registry, key: DependencyKey, cache: dict| None = None) type_=factory.type, cache=factory.cache, dependencies=[ - make_node(registry, dep, cache) + make_node(registry, dep, cache, depth+1) for dep in factory.dependencies ], kw_dependencies={ - key: make_node(registry, dep, cache) + key: make_node(registry, dep, cache, depth+1) for key, dep in factory.kw_dependencies.items() }, )