Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 63 additions & 59 deletions src/dishka/async_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from asyncio import Lock
from collections.abc import Callable, MutableMapping
from contextlib import AbstractAsyncContextManager
from contextlib import AbstractContextManager, AbstractAsyncContextManager
from types import TracebackType
from typing import Any, TypeVar, cast, overload

Expand Down Expand Up @@ -37,7 +37,8 @@ class AsyncContainer:
"_cache",
"_context",
"_exits",
"child_registries",
"_scope",
"_child_scopes",
"close_parent",
"lock",
"parent_container",
Expand All @@ -47,7 +48,8 @@ class AsyncContainer:
def __init__(
self,
registry: Registry,
*child_registries: Registry,
scope: BaseScope,
*child_scopes: BaseScope,
parent_container: AsyncContainer | None = None,
context: dict[Any, Any] | None = None,
lock_factory: Callable[
Expand All @@ -56,7 +58,8 @@ def __init__(
close_parent: bool = False,
):
self.registry = registry
self.child_registries = child_registries
self._scope = scope
self._child_scopes = child_scopes
self._context = {CONTAINER_KEY: self}
if context:
for key, value in context.items():
Expand All @@ -69,7 +72,7 @@ def __init__(
self._cache = {**self._context}
self.parent_container = parent_container

self.lock: AbstractAsyncContextManager[Any] | None
self.lock: AbstractContextManager[Any] | None
if lock_factory:
self.lock = lock_factory()
else:
Expand All @@ -79,7 +82,7 @@ def __init__(

@property
def scope(self) -> BaseScope:
return self.registry.scope
return self._scope

@property
def context(self) -> MutableMapping[DependencyKey, Any]:
Expand All @@ -94,7 +97,7 @@ def __call__(
self,
context: dict[Any, Any] | None = None,
lock_factory: Callable[
[], AbstractAsyncContextManager[Any],
[], AbstractContextManager[Any],
] | None = None,
scope: BaseScope | None = None,
) -> AsyncContextWrapper:
Expand All @@ -103,34 +106,36 @@ def __call__(
:param context: Data which will available in inner scope
:param lock_factory: Callable to create lock instance or None
:param scope: target scope or None to enter next non-skipped scope
:return: async context manager for inner scope
:return: context manager for inner scope
"""
if not self.child_registries:
if not self._child_scopes:
raise NoChildScopesError

child = AsyncContainer(
*self.child_registries,
self.registry,
*self._child_scopes,
parent_container=self,
context=context,
lock_factory=lock_factory,
)
if scope is None:
while child.registry.scope.skip:
if not child.child_registries:
while child.scope.skip:
if not child._child_scopes:
raise NoNonSkippedScopesError
child = AsyncContainer(
*child.child_registries,
self.registry,
*child._child_scopes,
parent_container=child,
context=context,
lock_factory=lock_factory,
close_parent=True,
)
else:
while child.registry.scope is not scope:
if not child.child_registries:
raise ChildScopeNotFoundError(scope, self.registry.scope)
while child.scope is not scope:
if not child._child_scopes:
raise ChildScopeNotFoundError(scope, self.scope)
child = AsyncContainer(
*child.child_registries,
self.registry,
*child._child_scopes,
parent_container=child,
context=context,
lock_factory=lock_factory,
Expand Down Expand Up @@ -182,40 +187,37 @@ async def _get_unlocked(self, key: DependencyKey) -> Any:
return self._cache[key]
compiled = self.registry.get_compiled_async(key)
if not compiled:
if not self.parent_container:
abstract_dependencies = (
self.registry.get_more_abstract_factories(key)
)
concrete_dependencies = (
self.registry.get_more_concrete_factories(key)
)
raise NoFactoryError(
key,
suggest_abstract_factories=abstract_dependencies,
suggest_concrete_factories=concrete_dependencies,
)
abstract_dependencies = (
self.registry.get_more_abstract_factories(key)
)
concrete_dependencies = (
self.registry.get_more_concrete_factories(key)
)

raise NoFactoryError(
key,
suggest_abstract_factories=abstract_dependencies,
suggest_concrete_factories=concrete_dependencies,
)

if compiled.scope == self.scope:
try:
return await self.parent_container._get(key) # noqa: SLF001
except NoFactoryError as ex:
abstract_dependencies = (
self.registry.get_more_abstract_factories(key)
)
concrete_dependencies = (
self.registry.get_more_concrete_factories(key)
)
ex.suggest_abstract_factories.extend(abstract_dependencies)
ex.suggest_concrete_factories.extend(concrete_dependencies)
return await compiled(self._get_unlocked, self._exits, self._cache)
except NoFactoryError as e:
# cast is needed because registry.get_factory will always
# return Factory. This happens because registry.get_compiled
# uses the same method and returns None if the factory is not found
# If None is returned, then go to the parent container
e.add_path(cast(Factory, self.registry.get_factory(key)))
raise
else:
parent = self.parent_container
while parent.scope != compiled.scope:
if not parent.parent_container:
raise NoFactoryError(key)
parent = parent.parent_container

try:
return await compiled(self._get_unlocked, self._exits, self._cache)
except NoFactoryError as e:
# cast is needed because registry.get_factory will always
# return Factory. This happens because registry.get_compiled
# uses the same method and returns None if the factory is not found
# If None is returned, then go to the parent container
e.add_path(cast(Factory, self.registry.get_factory(key)))
raise
return await parent._get(key)

async def close(self, exception: BaseException | None = None) -> None:
errors = []
Expand All @@ -242,6 +244,8 @@ async def close(self, exception: BaseException | None = None) -> None:


class AsyncContextWrapper:
__slots__ = ("container",)

def __init__(self, container: AsyncContainer):
self.container = container

Expand All @@ -261,40 +265,40 @@ def make_async_container(
*providers: BaseProvider,
scopes: type[BaseScope] = Scope,
context: dict[Any, Any] | None = None,
lock_factory: Callable[
[], AbstractAsyncContextManager[Any],
] | None = Lock,
lock_factory: Callable[[], AbstractContextManager[Any]] | None = Lock,
skip_validation: bool = False,
start_scope: BaseScope | None = None,
validation_settings: ValidationSettings = DEFAULT_VALIDATION,
) -> AsyncContainer:
context_provider = make_root_context_provider(providers, context, scopes)
registries = RegistryBuilder(
registry = RegistryBuilder(
scopes=scopes,
container_key=CONTAINER_KEY,
providers=(*providers, context_provider),
skip_validation=skip_validation,
validation_settings=validation_settings,
).build()
container = AsyncContainer(
*registries,
registry,
*scopes,
context=context,
lock_factory=lock_factory,
)

if start_scope is None:
while container.registry.scope.skip:
while container.scope.skip:
container = AsyncContainer(
*container.child_registries,
registry,
*container._child_scopes,
parent_container=container,
context=context,
lock_factory=lock_factory,
close_parent=True,
)
else:
while container.registry.scope is not start_scope:
while container.scope is not start_scope:
container = AsyncContainer(
*container.child_registries,
registry,
*container._child_scopes,
parent_container=container,
context=context,
lock_factory=lock_factory,
Expand Down
Loading
Loading