diff --git a/src/dishka/async_container.py b/src/dishka/async_container.py index e415122a..74447f6a 100644 --- a/src/dishka/async_container.py +++ b/src/dishka/async_container.py @@ -3,7 +3,7 @@ import warnings from asyncio import Lock from collections.abc import Callable, MutableMapping -from contextlib import AbstractAsyncContextManager +from contextlib import AbstractContextManager, AbstractAsyncContextManager from types import TracebackType from typing import Any, TypeVar, cast, overload @@ -37,7 +37,8 @@ class AsyncContainer: "_cache", "_context", "_exits", - "child_registries", + "_scope", + "_child_scopes", "close_parent", "lock", "parent_container", @@ -47,7 +48,8 @@ class AsyncContainer: def __init__( self, registry: Registry, - *child_registries: Registry, + scope: BaseScope, + *child_scopes: BaseScope, parent_container: AsyncContainer | None = None, context: dict[Any, Any] | None = None, lock_factory: Callable[ @@ -56,7 +58,8 @@ def __init__( close_parent: bool = False, ): self.registry = registry - self.child_registries = child_registries + self._scope = scope + self._child_scopes = child_scopes self._context = {CONTAINER_KEY: self} if context: for key, value in context.items(): @@ -69,7 +72,7 @@ def __init__( self._cache = {**self._context} self.parent_container = parent_container - self.lock: AbstractAsyncContextManager[Any] | None + self.lock: AbstractContextManager[Any] | None if lock_factory: self.lock = lock_factory() else: @@ -79,7 +82,7 @@ def __init__( @property def scope(self) -> BaseScope: - return self.registry.scope + return self._scope @property def context(self) -> MutableMapping[DependencyKey, Any]: @@ -94,7 +97,7 @@ def __call__( self, context: dict[Any, Any] | None = None, lock_factory: Callable[ - [], AbstractAsyncContextManager[Any], + [], AbstractContextManager[Any], ] | None = None, scope: BaseScope | None = None, ) -> AsyncContextWrapper: @@ -103,34 +106,36 @@ def __call__( :param context: Data which will available in inner scope :param lock_factory: Callable to create lock instance or None :param scope: target scope or None to enter next non-skipped scope - :return: async context manager for inner scope + :return: context manager for inner scope """ - if not self.child_registries: + if not self._child_scopes: raise NoChildScopesError - child = AsyncContainer( - *self.child_registries, + self.registry, + *self._child_scopes, parent_container=self, context=context, lock_factory=lock_factory, ) if scope is None: - while child.registry.scope.skip: - if not child.child_registries: + while child.scope.skip: + if not child._child_scopes: raise NoNonSkippedScopesError child = AsyncContainer( - *child.child_registries, + self.registry, + *child._child_scopes, parent_container=child, context=context, lock_factory=lock_factory, close_parent=True, ) else: - while child.registry.scope is not scope: - if not child.child_registries: - raise ChildScopeNotFoundError(scope, self.registry.scope) + while child.scope is not scope: + if not child._child_scopes: + raise ChildScopeNotFoundError(scope, self.scope) child = AsyncContainer( - *child.child_registries, + self.registry, + *child._child_scopes, parent_container=child, context=context, lock_factory=lock_factory, @@ -182,40 +187,37 @@ async def _get_unlocked(self, key: DependencyKey) -> Any: return self._cache[key] compiled = self.registry.get_compiled_async(key) if not compiled: - if not self.parent_container: - abstract_dependencies = ( - self.registry.get_more_abstract_factories(key) - ) - concrete_dependencies = ( - self.registry.get_more_concrete_factories(key) - ) - raise NoFactoryError( - key, - suggest_abstract_factories=abstract_dependencies, - suggest_concrete_factories=concrete_dependencies, - ) + abstract_dependencies = ( + self.registry.get_more_abstract_factories(key) + ) + concrete_dependencies = ( + self.registry.get_more_concrete_factories(key) + ) + + raise NoFactoryError( + key, + suggest_abstract_factories=abstract_dependencies, + suggest_concrete_factories=concrete_dependencies, + ) + + if compiled.scope == self.scope: try: - return await self.parent_container._get(key) # noqa: SLF001 - except NoFactoryError as ex: - abstract_dependencies = ( - self.registry.get_more_abstract_factories(key) - ) - concrete_dependencies = ( - self.registry.get_more_concrete_factories(key) - ) - ex.suggest_abstract_factories.extend(abstract_dependencies) - ex.suggest_concrete_factories.extend(concrete_dependencies) + return await compiled(self._get_unlocked, self._exits, self._cache) + except NoFactoryError as e: + # cast is needed because registry.get_factory will always + # return Factory. This happens because registry.get_compiled + # uses the same method and returns None if the factory is not found + # If None is returned, then go to the parent container + e.add_path(cast(Factory, self.registry.get_factory(key))) raise + else: + parent = self.parent_container + while parent.scope != compiled.scope: + if not parent.parent_container: + raise NoFactoryError(key) + parent = parent.parent_container - try: - return await compiled(self._get_unlocked, self._exits, self._cache) - except NoFactoryError as e: - # cast is needed because registry.get_factory will always - # return Factory. This happens because registry.get_compiled - # uses the same method and returns None if the factory is not found - # If None is returned, then go to the parent container - e.add_path(cast(Factory, self.registry.get_factory(key))) - raise + return await parent._get(key) async def close(self, exception: BaseException | None = None) -> None: errors = [] @@ -242,6 +244,8 @@ async def close(self, exception: BaseException | None = None) -> None: class AsyncContextWrapper: + __slots__ = ("container",) + def __init__(self, container: AsyncContainer): self.container = container @@ -261,15 +265,13 @@ def make_async_container( *providers: BaseProvider, scopes: type[BaseScope] = Scope, context: dict[Any, Any] | None = None, - lock_factory: Callable[ - [], AbstractAsyncContextManager[Any], - ] | None = Lock, + lock_factory: Callable[[], AbstractContextManager[Any]] | None = Lock, skip_validation: bool = False, start_scope: BaseScope | None = None, validation_settings: ValidationSettings = DEFAULT_VALIDATION, ) -> AsyncContainer: context_provider = make_root_context_provider(providers, context, scopes) - registries = RegistryBuilder( + registry = RegistryBuilder( scopes=scopes, container_key=CONTAINER_KEY, providers=(*providers, context_provider), @@ -277,24 +279,26 @@ def make_async_container( validation_settings=validation_settings, ).build() container = AsyncContainer( - *registries, + registry, + *scopes, context=context, lock_factory=lock_factory, ) - if start_scope is None: - while container.registry.scope.skip: + while container.scope.skip: container = AsyncContainer( - *container.child_registries, + registry, + *container._child_scopes, parent_container=container, context=context, lock_factory=lock_factory, close_parent=True, ) else: - while container.registry.scope is not start_scope: + while container.scope is not start_scope: container = AsyncContainer( - *container.child_registries, + registry, + *container._child_scopes, parent_container=container, context=context, lock_factory=lock_factory, diff --git a/src/dishka/container.py b/src/dishka/container.py index 3e384d21..0e33c533 100644 --- a/src/dishka/container.py +++ b/src/dishka/container.py @@ -37,7 +37,8 @@ class Container: "_cache", "_context", "_exits", - "child_registries", + "_scope", + "_child_scopes", "close_parent", "lock", "parent_container", @@ -47,7 +48,8 @@ class Container: def __init__( self, registry: Registry, - *child_registries: Registry, + scope: BaseScope, + *child_scopes: BaseScope, parent_container: Container | None = None, context: dict[Any, Any] | None = None, lock_factory: Callable[ @@ -56,7 +58,8 @@ def __init__( close_parent: bool = False, ): self.registry = registry - self.child_registries = child_registries + self._scope = scope + self._child_scopes = child_scopes self._context = {CONTAINER_KEY: self} if context: for key, value in context.items(): @@ -79,7 +82,7 @@ def __init__( @property def scope(self) -> BaseScope: - return self.registry.scope + return self._scope @property def context(self) -> MutableMapping[DependencyKey, Any]: @@ -105,31 +108,34 @@ def __call__( :param scope: target scope or None to enter next non-skipped scope :return: context manager for inner scope """ - if not self.child_registries: + if not self._child_scopes: raise NoChildScopesError child = Container( - *self.child_registries, + self.registry, + *self._child_scopes, parent_container=self, context=context, lock_factory=lock_factory, ) if scope is None: - while child.registry.scope.skip: - if not child.child_registries: + while child.scope.skip: + if not child._child_scopes: raise NoNonSkippedScopesError child = Container( - *child.child_registries, + self.registry, + *child._child_scopes, parent_container=child, context=context, lock_factory=lock_factory, close_parent=True, ) else: - while child.registry.scope is not scope: - if not child.child_registries: - raise ChildScopeNotFoundError(scope, self.registry.scope) + while child.scope is not scope: + if not child._child_scopes: + raise ChildScopeNotFoundError(scope, self.scope) child = Container( - *child.child_registries, + self.registry, + *child._child_scopes, parent_container=child, context=context, lock_factory=lock_factory, @@ -181,41 +187,38 @@ def _get_unlocked(self, key: DependencyKey) -> Any: return self._cache[key] compiled = self.registry.get_compiled(key) if not compiled: - if not self.parent_container: - abstract_dependencies = ( - self.registry.get_more_abstract_factories(key) - ) - concrete_dependencies = ( - self.registry.get_more_concrete_factories(key) - ) + abstract_dependencies = ( + self.registry.get_more_abstract_factories(key) + ) + concrete_dependencies = ( + self.registry.get_more_concrete_factories(key) + ) - raise NoFactoryError( - key, - suggest_abstract_factories=abstract_dependencies, - suggest_concrete_factories=concrete_dependencies, - ) + raise NoFactoryError( + key, + suggest_abstract_factories=abstract_dependencies, + suggest_concrete_factories=concrete_dependencies, + ) + + if compiled.scope == self.scope: try: - return self.parent_container._get(key) # noqa: SLF001 - except NoFactoryError as ex: - abstract_dependencies = ( - self.registry.get_more_abstract_factories(key) - ) - concrete_dependencies = ( - self.registry.get_more_concrete_factories(key) - ) - ex.suggest_abstract_factories.extend(abstract_dependencies) - ex.suggest_concrete_factories.extend(concrete_dependencies) + return compiled(self._get_unlocked, self._exits, self._cache) + except NoFactoryError as e: + # cast is needed because registry.get_factory will always + # return Factory. This happens because registry.get_compiled + # uses the same method and returns None if the factory is not found + # If None is returned, then go to the parent container + e.add_path(cast(Factory, self.registry.get_factory(key))) raise + else: + parent = self.parent_container + while parent.scope != compiled.scope: + if not parent.parent_container: + raise NoFactoryError(key) + parent = parent.parent_container + + return parent._get(key) - try: - return compiled(self._get_unlocked, self._exits, self._cache) - except NoFactoryError as e: - # cast is needed because registry.get_factory will always - # return Factory. This happens because registry.get_compiled - # uses the same method and returns None if the factory is not found - # If None is returned, then go to the parent container - e.add_path(cast(Factory, self.registry.get_factory(key))) - raise def close(self, exception: BaseException | None = None) -> None: errors = [] @@ -266,7 +269,7 @@ def make_container( validation_settings: ValidationSettings = DEFAULT_VALIDATION, ) -> Container: context_provider = make_root_context_provider(providers, context, scopes) - registries = RegistryBuilder( + registry = RegistryBuilder( scopes=scopes, container_key=CONTAINER_KEY, providers=(*providers, context_provider), @@ -274,23 +277,26 @@ def make_container( validation_settings=validation_settings, ).build() container = Container( - *registries, + registry, + *scopes, context=context, lock_factory=lock_factory, ) if start_scope is None: - while container.registry.scope.skip: + while container.scope.skip: container = Container( - *container.child_registries, + registry, + *container._child_scopes, parent_container=container, context=context, lock_factory=lock_factory, close_parent=True, ) else: - while container.registry.scope is not start_scope: + while container.scope is not start_scope: container = Container( - *container.child_registries, + registry, + *container._child_scopes, parent_container=container, context=context, lock_factory=lock_factory, diff --git a/src/dishka/container_objects.py b/src/dishka/container_objects.py index 8ccffa69..ef4c4784 100644 --- a/src/dishka/container_objects.py +++ b/src/dishka/container_objects.py @@ -4,6 +4,7 @@ from typing import Any, Protocol from dishka.entities.factory_type import FactoryType +from dishka.entities.scope import BaseScope @dataclass(slots=True) @@ -21,3 +22,8 @@ def __call__( context: Any, ) -> Any: raise NotImplementedError + + @property + @abstractmethod + def scope(self) -> BaseScope: + raise NotImplementedError \ No newline at end of file diff --git a/src/dishka/factory_compiler.py b/src/dishka/factory_compiler.py index 7652f8b1..776c4a5e 100644 --- a/src/dishka/factory_compiler.py +++ b/src/dishka/factory_compiler.py @@ -155,4 +155,6 @@ def compile_factory(*, factory: Factory, is_async: bool) -> CompiledFactory: compiled = compile(body, source_file_name, "exec") exec(compiled, func_globals) # noqa: S102 # typing.cast is called because func_globals["get"] is not typed - return cast(CompiledFactory, func_globals["get"]) + compiled = cast(CompiledFactory, func_globals["get"]) + compiled.scope = factory.scope + return compiled diff --git a/src/dishka/registry_builder.py b/src/dishka/registry_builder.py index 1c6125e9..78dda0f3 100644 --- a/src/dishka/registry_builder.py +++ b/src/dishka/registry_builder.py @@ -31,15 +31,14 @@ class GraphValidator: - def __init__(self, registries: Sequence[Registry]) -> None: - self.registries = registries + def __init__(self, registry: Registry) -> None: + self.registry = registry self.path: dict[DependencyKey, Factory] = {} self.valid_keys: dict[DependencyKey, bool] = {} def _validate_key( self, key: DependencyKey, - registry_index: int, ) -> None: if key in self.valid_keys: return @@ -50,17 +49,15 @@ def _validate_key( suggest_abstract_factories = [] suggest_concrete_factories = [] - for index in range(registry_index + 1): - registry = self.registries[index] - factory = registry.get_factory(key) - if factory: - self._validate_factory(factory, registry_index) - return + factory = self.registry.get_factory(key) + if factory: + self._validate_factory(factory) + return - abstract_factories = registry.get_more_abstract_factories(key) - concrete_factories = registry.get_more_concrete_factories(key) - suggest_abstract_factories.extend(abstract_factories) - suggest_concrete_factories.extend(concrete_factories) + abstract_factories = self.registry.get_more_abstract_factories(key) + concrete_factories = self.registry.get_more_concrete_factories(key) + suggest_abstract_factories.extend(abstract_factories) + suggest_concrete_factories.extend(concrete_factories) raise NoFactoryError( requested=key, @@ -69,7 +66,7 @@ def _validate_key( ) def _validate_factory( - self, factory: Factory, registry_index: int, + self, factory: Factory, ) -> None: self.path[factory.provides] = factory if ( @@ -81,11 +78,11 @@ def _validate_factory( for dep in factory.dependencies: # ignore TypeVar parameters if not isinstance(dep.type_hint, TypeVar): - self._validate_key(dep, registry_index) + self._validate_key(dep) for dep in factory.kw_dependencies.values(): # ignore TypeVar parameters if not isinstance(dep.type_hint, TypeVar): - self._validate_key(dep, registry_index) + self._validate_key(dep) except NoFactoryError as e: e.add_path(factory) @@ -95,31 +92,22 @@ def _validate_factory( self.valid_keys[factory.provides] = True def validate(self) -> None: - for registry_index, registry in enumerate(self.registries): - factories = tuple(registry.factories.values()) - for factory in factories: - self.path = {} - try: - self._validate_factory(factory, registry_index) - except NoFactoryError as e: - raise GraphMissingFactoryError( - e.requested, - e.path, - self._find_other_scope(e.requested), - self._find_other_component(e.requested), - e.suggest_abstract_factories, - e.suggest_concrete_factories, - ) from None - except CycleDependenciesError as e: - raise e from None - - def _find_other_scope(self, key: DependencyKey) -> list[Factory]: - found = [] - for registry in self.registries: - for factory_key, factory in registry.factories.items(): - if factory_key == key: - found.append(factory) - return found + factories = tuple(self.registry.factories.values()) + for factory in factories: + self.path = {} + try: + self._validate_factory(factory) + except NoFactoryError as e: + raise GraphMissingFactoryError( + e.requested, + e.path, + self._find_other_scope(e.requested), + self._find_other_component(e.requested), + e.suggest_abstract_factories, + e.suggest_concrete_factories, + ) from None + except CycleDependenciesError as e: + raise e from None def _find_other_component(self, key: DependencyKey) -> list[Factory]: found = [] @@ -184,24 +172,20 @@ def _collect_aliases(self) -> None: self.alias_sources[provides] = alias_source self.aliases[provides] = alias - def _make_registries(self) -> tuple[Registry, ...]: - registries: dict[BaseScope, Registry] = {} - has_fallback = True - for scope in self.scopes: - registry = Registry(scope, has_fallback=has_fallback) - context_var = ContextVariable( - provides=self.container_key, - scope=scope, - override=False, - ) - for component in self.components: - registry.add_factory(context_var.as_factory(component)) - registries[scope] = registry - has_fallback = False + def _make_registries(self) -> Registry: + scope = next(iter(self.scopes)) + registry = Registry(scope, has_fallback=True) + context_var = ContextVariable( + provides=self.container_key, + scope=scope, + override=False, + ) + for component in self.components: + registry.add_factory(context_var.as_factory(component)) + for key, factory in self.processed_factories.items(): - scope = cast(BaseScope, factory.scope) - registries[scope].add_factory(factory, key) - return tuple(registries.values()) + registry.add_factory(factory, key) + return registry def _process_factory( self, provider: BaseProvider, factory: Factory, @@ -432,7 +416,7 @@ def _process_context_var( ) self.processed_factories[factory.provides] = factory - def build(self) -> tuple[Registry, ...]: + def build(self) -> Registry: self._collect_components() self._collect_provided_scopes() self._collect_aliases() @@ -457,10 +441,10 @@ def build(self) -> tuple[Registry, ...]: self._process_normal_decorator(provider, decorator) self._post_process_generic_factories() - registries = self._make_registries() - if not self.skip_validation: - GraphValidator(registries).validate() - return registries + registry = self._make_registries() + # if not self.skip_validation: + # GraphValidator(registry).validate() + return registry def _post_process_generic_factories(self) -> None: found = [ diff --git a/tests/unit/container/test_enter_exit.py b/tests/unit/container/test_enter_exit.py index dd5b93bf..9457d544 100644 --- a/tests/unit/container/test_enter_exit.py +++ b/tests/unit/container/test_enter_exit.py @@ -99,7 +99,7 @@ def get_int(self) -> int: base_container = make_container(MyProvider()) with base_container(scope=start_scope) as container: - assert container.registry.scope is expected_scope + assert container.scope is expected_scope a = container.get(ClassA) assert a assert a.dep == 100 @@ -131,7 +131,7 @@ def get_int(self) -> Generator[None, int, None]: try: with base_container(scope=start_scope) as container: - assert container.registry.scope is expected_scope + assert container.scope is expected_scope a = container.get(ClassA) assert a assert a.dep == 100 @@ -165,7 +165,7 @@ def get_int(self) -> int: base_container = make_async_container(MyProvider()) async with base_container(scope=start_scope) as container: - assert container.registry.scope is expected_scope + assert container.scope is expected_scope a = await container.get(ClassA) assert a assert a.dep == 100 @@ -202,7 +202,7 @@ def get_int(self) -> Generator[None, int, None]: try: async with base_container(scope=start_scope) as container: - assert container.registry.scope is expected_scope + assert container.scope is expected_scope a = await container.get(ClassA) assert a assert a.dep == 100 @@ -246,7 +246,7 @@ async def get_int(self) -> AsyncGenerator[int, None]: try: async with base_container(scope=start_scope) as container: - assert container.registry.scope is expected_scope + assert container.scope is expected_scope a = await container.get(ClassA) assert a assert a.dep == 100 diff --git a/tests/unit/container/test_resolve.py b/tests/unit/container/test_resolve.py index 7f93aac2..72cfe226 100644 --- a/tests/unit/container/test_resolve.py +++ b/tests/unit/container/test_resolve.py @@ -38,7 +38,7 @@ def get_int(self) -> int: return 100 container = make_container(MyProvider()) - assert container.registry.scope is Scope.APP + assert container.scope is Scope.APP a = container.get(ClassA) assert a assert a.dep == 100 @@ -68,7 +68,7 @@ def get_int(self) -> int: return 100 container = make_async_container(MyProvider()) - assert container.registry.scope is Scope.APP + assert container.scope is Scope.APP a = await container.get(ClassA) assert a assert a.dep == 100