diff --git a/src/dishka/__init__.py b/src/dishka/__init__.py index 99f97232..2bd8caa5 100644 --- a/src/dishka/__init__.py +++ b/src/dishka/__init__.py @@ -1,6 +1,8 @@ __all__ = [ "DEFAULT_COMPONENT", "STRICT_VALIDATION", + "ActivationContext", + "Activator", "AnyOf", "AsyncContainer", "BaseScope", @@ -9,6 +11,7 @@ "DependencyKey", "FromComponent", "FromDishka", + "Has", "Provider", "Scope", "ValidationSettings", @@ -25,6 +28,7 @@ from .async_container import AsyncContainer, make_async_container from .container import Container, make_container +from .entities.activator import ActivationContext, Activator, Has from .entities.component import DEFAULT_COMPONENT, Component from .entities.depends_marker import FromDishka from .entities.key import DependencyKey, FromComponent diff --git a/src/dishka/async_container.py b/src/dishka/async_container.py index e415122a..6ff41bd8 100644 --- a/src/dishka/async_container.py +++ b/src/dishka/async_container.py @@ -275,6 +275,7 @@ def make_async_container( providers=(*providers, context_provider), skip_validation=skip_validation, validation_settings=validation_settings, + root_context=context, ).build() container = AsyncContainer( *registries, diff --git a/src/dishka/container.py b/src/dishka/container.py index 3e384d21..9263dea5 100644 --- a/src/dishka/container.py +++ b/src/dishka/container.py @@ -272,6 +272,7 @@ def make_container( providers=(*providers, context_provider), skip_validation=skip_validation, validation_settings=validation_settings, + root_context=context, ).build() container = Container( *registries, diff --git a/src/dishka/dependency_source/alias.py b/src/dishka/dependency_source/alias.py index 713bafb2..52ccaa4e 100644 --- a/src/dishka/dependency_source/alias.py +++ b/src/dishka/dependency_source/alias.py @@ -2,6 +2,7 @@ from typing import Any +from dishka.entities.activator import Activator from dishka.entities.component import Component from dishka.entities.factory_type import FactoryType from dishka.entities.key import DependencyKey @@ -14,7 +15,10 @@ def _identity(x: Any) -> Any: class Alias: - __slots__ = ("cache", "component", "override", "provides", "source") + __slots__ = ( + "cache", "component", "override", + "provides", "source", "when", + ) def __init__( self, *, @@ -22,11 +26,13 @@ def __init__( provides: DependencyKey, cache: bool, override: bool, + when: Activator | None, ) -> None: self.source = source self.provides = provides self.cache = cache self.override = override + self.when = when def as_factory( self, scope: BaseScope | None, component: Component | None, @@ -41,6 +47,7 @@ def as_factory( type_=FactoryType.ALIAS, cache=self.cache, override=self.override, + when=self.when, ) def __get__(self, instance: Any, owner: Any) -> Alias: diff --git a/src/dishka/dependency_source/context_var.py b/src/dishka/dependency_source/context_var.py index 325d7ede..bf707497 100644 --- a/src/dishka/dependency_source/context_var.py +++ b/src/dishka/dependency_source/context_var.py @@ -2,6 +2,7 @@ from typing import Any, NoReturn +from dishka.entities.activator import Activator from dishka.entities.component import DEFAULT_COMPONENT, Component from dishka.entities.factory_type import FactoryType from dishka.entities.key import DependencyKey @@ -15,21 +16,21 @@ def context_stub() -> NoReturn: class ContextVariable: - __slots__ = ("override", "provides", "scope") + __slots__ = ("override", "provides", "scope", "when") def __init__( self, *, provides: DependencyKey, scope: BaseScope | None, override: bool, + when: Activator | None, ) -> None: self.provides = provides self.scope = scope self.override = override + self.when = when - def as_factory( - self, component: Component, - ) -> Factory: + def as_factory(self, component: Component | None) -> Factory: if component == DEFAULT_COMPONENT: return Factory( scope=self.scope, @@ -41,6 +42,7 @@ def as_factory( type_=FactoryType.CONTEXT, cache=False, override=self.override, + when=self.when, ) else: aliased = Alias( @@ -51,6 +53,7 @@ def as_factory( component=component, type_hint=self.provides.type_hint, ), + when=self.when, ) return aliased.as_factory(scope=self.scope, component=component) @@ -60,4 +63,5 @@ def __get__(self, instance: Any, owner: Any) -> ContextVariable: scope=scope, provides=self.provides, override=self.override, + when=self.when, ) diff --git a/src/dishka/dependency_source/decorator.py b/src/dishka/dependency_source/decorator.py index 1de073d7..27c213e4 100644 --- a/src/dishka/dependency_source/decorator.py +++ b/src/dishka/dependency_source/decorator.py @@ -2,6 +2,7 @@ from typing import Any, TypeVar, get_args, get_origin +from dishka.entities.activator import Activator from dishka.entities.component import Component from dishka.entities.key import DependencyKey from dishka.entities.scope import BaseScope @@ -10,13 +11,14 @@ class Decorator: - __slots__ = ("factory", "provides", "scope") + __slots__ = ("factory", "provides", "scope", "when") def __init__( self, factory: Factory, provides: DependencyKey | None = None, scope: BaseScope | None = None, + when: Activator | None = None, ) -> None: self.factory = factory if provides: @@ -24,6 +26,7 @@ def __init__( else: self.provides = factory.provides self.scope = scope + self.when = when def is_generic(self) -> bool: return ( @@ -39,7 +42,7 @@ def as_factory( scope: BaseScope, new_dependency: DependencyKey, cache: bool, - component: Component, + component: Component | None, ) -> Factory: typevar_replacement = get_typevar_replacement( self.provides.type_hint, @@ -67,6 +70,7 @@ def as_factory( type_=self.factory.type, cache=cache, override=False, + when=self.when, ) def _replace_dep( @@ -95,4 +99,5 @@ def __get__(self, instance: Any, owner: Any) -> Decorator: return Decorator( self.factory.__get__(instance, owner), scope=self.scope, + when=self.when, ) diff --git a/src/dishka/dependency_source/factory.py b/src/dishka/dependency_source/factory.py index db30242a..68a90e56 100644 --- a/src/dishka/dependency_source/factory.py +++ b/src/dishka/dependency_source/factory.py @@ -6,6 +6,7 @@ ) from typing import Any +from dishka.entities.activator import Activator from dishka.entities.component import Component from dishka.entities.factory_type import FactoryData, FactoryType from dishka.entities.key import DependencyKey @@ -19,6 +20,7 @@ class Factory(FactoryData): "is_to_bind", "kw_dependencies", "override", + "when", ) def __init__( @@ -33,6 +35,7 @@ def __init__( is_to_bind: bool, cache: bool, override: bool, + when: Activator | None, ) -> None: super().__init__( source=source, @@ -45,6 +48,7 @@ def __init__( self.is_to_bind = is_to_bind self.cache = cache self.override = override + self.when = when def __get__(self, instance: Any, owner: Any) -> Factory: scope = self.scope or instance.scope @@ -66,9 +70,10 @@ def __get__(self, instance: Any, owner: Any) -> Factory: is_to_bind=False, cache=self.cache, override=self.override, + when=self.when, ) - def with_component(self, component: Component) -> Factory: + def with_component(self, component: Component | None) -> Factory: return Factory( dependencies=[ d.with_component(component) for d in self.dependencies @@ -84,4 +89,5 @@ def with_component(self, component: Component) -> Factory: cache=self.cache, type_=self.type, override=self.override, + when=self.when, ) diff --git a/src/dishka/entities/activator.py b/src/dishka/entities/activator.py new file mode 100644 index 00000000..957c8b3c --- /dev/null +++ b/src/dishka/entities/activator.py @@ -0,0 +1,43 @@ +from abc import abstractmethod +from collections.abc import Callable +from typing import Any, NamedTuple, Protocol, TypeAlias + +from dishka.entities.component import Component +from dishka.entities.key import DependencyKey + + +class ActivationBuilder(Protocol): + @abstractmethod + def has_active( + self, + key: DependencyKey, + request_stack: list[DependencyKey], + ) -> bool: + raise NotImplementedError + + +class ActivationContext(NamedTuple): + container_context: dict[Any, Any] | None + container_key: DependencyKey + key: DependencyKey + builder: ActivationBuilder + request_stack: list[DependencyKey] + + +Activator: TypeAlias = Callable[[ActivationContext], bool] + + +class Has: + def __init__( + self, + cls: Any, + *, + component: Component | None = None, + ) -> None: + self.key = DependencyKey(cls, component=component) + + def __call__(self, ctx: ActivationContext) -> bool: + key = self.key.with_component(ctx.key.component) + if key in ctx.request_stack: # cycle + return True + return ctx.builder.has_active(key, ctx.request_stack) diff --git a/src/dishka/provider/base_provider.py b/src/dishka/provider/base_provider.py index ec2ea4f9..9167c66d 100644 --- a/src/dishka/provider/base_provider.py +++ b/src/dishka/provider/base_provider.py @@ -4,13 +4,23 @@ Decorator, Factory, ) +from dishka.entities.activator import Activator from dishka.entities.component import Component class BaseProvider: - def __init__(self, component: Component | None) -> None: + when: Activator | None = None + component: Component | None = None + + def __init__( + self, + component: Component | None, + when: Activator | None = None, + ) -> None: if component is not None: self.component = component + if when is not None: + self.when = when self.factories: list[Factory] = [] self.aliases: list[Alias] = [] self.decorators: list[Decorator] = [] @@ -18,8 +28,13 @@ def __init__(self, component: Component | None) -> None: class ProviderWrapper(BaseProvider): - def __init__(self, component: Component, provider: BaseProvider) -> None: - super().__init__(component) + def __init__( + self, + component: Component, + provider: BaseProvider, + when: Activator | None = None, + ) -> None: + super().__init__(component, when=when) self.factories.extend(provider.factories) self.aliases.extend(provider.aliases) self.decorators.extend(provider.decorators) diff --git a/src/dishka/provider/make_alias.py b/src/dishka/provider/make_alias.py index 0f730859..c238055f 100644 --- a/src/dishka/provider/make_alias.py +++ b/src/dishka/provider/make_alias.py @@ -5,6 +5,7 @@ CompositeDependencySource, ensure_composite, ) +from dishka.entities.activator import Activator from dishka.entities.component import Component from dishka.entities.key import hint_to_dependency_key from .unpack_provides import unpack_alias @@ -17,6 +18,7 @@ def alias( cache: bool = True, component: Component | None = None, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: if component is provides is None: raise ValueError( # noqa: TRY003 @@ -31,6 +33,7 @@ def alias( provides=hint_to_dependency_key(provides), cache=cache, override=override, + when=when, ) composite.dependency_sources.extend(unpack_alias(alias_instance)) return composite diff --git a/src/dishka/provider/make_context_var.py b/src/dishka/provider/make_context_var.py index 67cc7c6e..7a6a3093 100644 --- a/src/dishka/provider/make_context_var.py +++ b/src/dishka/provider/make_context_var.py @@ -6,6 +6,7 @@ ContextVariable, context_stub, ) +from dishka.entities.activator import Activator from dishka.entities.component import DEFAULT_COMPONENT from dishka.entities.key import DependencyKey from dishka.entities.scope import BaseScope @@ -20,6 +21,7 @@ def from_context( *, scope: BaseScope | None = None, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: composite = CompositeDependencySource(origin=context_stub) composite.dependency_sources.append( @@ -30,6 +32,7 @@ def from_context( type_hint=provides, component=DEFAULT_COMPONENT, ), + when=when, ), ) @@ -41,6 +44,7 @@ def from_context( provides=DependencyKey(base_type, DEFAULT_COMPONENT), cache=True, override=override, + when=when, ), ) return composite diff --git a/src/dishka/provider/make_decorator.py b/src/dishka/provider/make_decorator.py index 74522777..c5a33259 100644 --- a/src/dishka/provider/make_decorator.py +++ b/src/dishka/provider/make_decorator.py @@ -6,6 +6,7 @@ Decorator, ensure_composite, ) +from dishka.entities.activator import Activator from dishka.entities.scope import BaseScope from .exceptions import IndependentDecoratorError from .make_factory import make_factory @@ -18,6 +19,7 @@ def _decorate( scope: BaseScope | None, *, is_in_class: bool = True, + when: Activator | None = None, ) -> CompositeDependencySource: composite = ensure_composite(source) decorator = Decorator( @@ -28,8 +30,10 @@ def _decorate( cache=False, is_in_class=is_in_class, override=False, + when=when, ), scope=scope, + when=when, ) if ( decorator.provides not in decorator.factory.kw_dependencies.values() @@ -46,6 +50,7 @@ def decorate( *, provides: Any = None, scope: BaseScope | None = None, + when: Activator | None = None, ) -> Callable[ [Callable[..., Any]], CompositeDependencySource, ]: @@ -58,6 +63,7 @@ def decorate( *, provides: Any = None, scope: BaseScope | None = None, + when: Activator | None = None, ) -> CompositeDependencySource: ... @@ -66,14 +72,19 @@ def decorate( source: Callable[..., Any] | type | None = None, provides: Any = None, scope: BaseScope | None = None, + when: Activator | None = None, ) -> CompositeDependencySource | Callable[ [Callable[..., Any]], CompositeDependencySource, ]: if source is not None: - return _decorate(source, provides, scope=scope, is_in_class=True) + return _decorate( + source, provides, scope=scope, is_in_class=True, when=when, + ) def scoped(func: Callable[..., Any]) -> CompositeDependencySource: - return _decorate(func, provides, scope=scope, is_in_class=True) + return _decorate( + func, provides, scope=scope, is_in_class=True, when=when, + ) return scoped @@ -82,5 +93,8 @@ def decorate_on_instance( source: Callable[..., Any] | type, provides: Any, scope: BaseScope | None, + when: Activator | None = None, ) -> CompositeDependencySource: - return _decorate(source, provides, scope=scope, is_in_class=False) + return _decorate( + source, provides, scope=scope, is_in_class=False, when=when, + ) diff --git a/src/dishka/provider/make_factory.py b/src/dishka/provider/make_factory.py index 171a6cfc..b6c18855 100644 --- a/src/dishka/provider/make_factory.py +++ b/src/dishka/provider/make_factory.py @@ -53,6 +53,7 @@ Factory, ensure_composite, ) +from dishka.entities.activator import Activator from dishka.entities.factory_type import FactoryType from dishka.entities.key import ( dependency_key_to_hint, @@ -248,6 +249,7 @@ def _make_factory_by_class( source: type, cache: bool, override: bool, + when: Activator | None, ) -> Factory: if not provides: provides = source @@ -276,6 +278,7 @@ def _make_factory_by_class( is_to_bind=False, cache=cache, override=override, + when=when, ) @@ -305,6 +308,7 @@ def _make_factory_by_function( cache: bool, is_in_class: bool, override: bool, + when: Activator | None, check_self_name: bool, ) -> Factory: # typing.cast is applied as unwrap takes a Callable object @@ -348,6 +352,7 @@ def _make_factory_by_function( is_to_bind=is_in_class, cache=cache, override=override, + when=when, ) @@ -358,6 +363,7 @@ def _make_factory_by_static_method( source: staticmethod, # type: ignore[type-arg] cache: bool, override: bool, + when: Activator | None, ) -> Factory: if missing_hints := _params_without_hints(source, skip_self=False): raise MissingHintsError(source, missing_hints) @@ -388,6 +394,7 @@ def _make_factory_by_static_method( is_to_bind=False, cache=cache, override=override, + when=when, ) @@ -398,6 +405,7 @@ def _make_factory_by_other_callable( source: Callable[..., Any], cache: bool, override: bool, + when: Activator | None, ) -> Factory: if _is_bound_method(source): to_check = source.__func__ # type: ignore[attr-defined] @@ -418,6 +426,7 @@ def _make_factory_by_other_callable( is_in_class=is_in_class, override=override, check_self_name=False, + when=when, ) if factory.is_to_bind: dependencies = factory.dependencies[1:] # remove `self` @@ -433,6 +442,7 @@ def _make_factory_by_other_callable( is_to_bind=False, cache=cache, override=override, + when=when, ) @@ -461,6 +471,7 @@ def make_factory( cache: bool, is_in_class: bool, override: bool, + when: Activator | None, ) -> Factory: provides, source = _extract_source(provides, source) @@ -479,6 +490,7 @@ def make_factory( source=cast(type, source), cache=cache, override=override, + when=when, ) elif isfunction(source) or isinstance(source, classmethod): return _make_factory_by_function( @@ -488,6 +500,7 @@ def make_factory( cache=cache, is_in_class=is_in_class, override=override, + when=when, check_self_name=True, ) elif isbuiltin(source): @@ -498,6 +511,7 @@ def make_factory( cache=cache, is_in_class=False, override=override, + when=when, check_self_name=False, ) elif isinstance(source, staticmethod): @@ -507,6 +521,7 @@ def make_factory( source=source, cache=cache, override=override, + when=when, ) elif callable(source) and not source_origin: return _make_factory_by_other_callable( @@ -515,6 +530,7 @@ def make_factory( source=source, cache=cache, override=override, + when=when, ) else: raise NotAFactoryError(source) @@ -529,13 +545,14 @@ def _provide( is_in_class: bool = True, recursive: bool = False, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: composite = ensure_composite(source) factory = make_factory( provides=provides, scope=scope, source=composite.origin, cache=cache, is_in_class=is_in_class, - override=override, + override=override, when=when, ) composite.dependency_sources.extend(unpack_factory(factory)) if not recursive: @@ -566,11 +583,12 @@ def provide_on_instance( cache: bool = True, recursive: bool = False, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: return _provide( provides=provides, scope=scope, source=source, cache=cache, is_in_class=False, - recursive=recursive, override=override, + recursive=recursive, override=override, when=when, ) @@ -582,6 +600,7 @@ def provide( cache: bool = True, recursive: bool = False, override: bool = False, + when: Activator | None = None, ) -> Callable[[Callable[..., Any]], CompositeDependencySource]: ... @@ -595,6 +614,7 @@ def provide( cache: bool = True, recursive: bool = False, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: ... @@ -607,6 +627,7 @@ def provide( cache: bool = True, recursive: bool = False, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource | Callable[ [Callable[..., Any]], CompositeDependencySource, ]: @@ -630,18 +651,21 @@ def provide( :param cache: save created object to scope cache or not :param recursive: register dependencies as factories as well :param override: dependency override + :param when: activation func aka `def activate(ctx: ActivationCtx) -> bool` :return: instance of Factory or a decorator returning it """ if source is not None: return _provide( provides=provides, scope=scope, source=source, cache=cache, - is_in_class=True, recursive=recursive, override=override, + is_in_class=True, + recursive=recursive, override=override, when=when, ) def scoped(func: Callable[..., Any]) -> CompositeDependencySource: return _provide( provides=provides, scope=scope, source=func, cache=cache, - is_in_class=True, recursive=recursive, override=override, + is_in_class=True, + recursive=recursive, override=override, when=when, ) return scoped @@ -655,6 +679,7 @@ def _provide_all( is_in_class: bool, recursive: bool, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: composite = CompositeDependencySource(None) for single_provides in provides: @@ -666,6 +691,7 @@ def _provide_all( is_in_class=is_in_class, recursive=recursive, override=override, + when=when, ) composite.dependency_sources.extend(source.dependency_sources) return composite @@ -677,11 +703,12 @@ def provide_all( cache: bool = True, recursive: bool = False, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: return _provide_all( - provides=provides, scope=scope, - cache=cache, is_in_class=True, - recursive=recursive, override=override, + provides=provides, scope=scope, cache=cache, + is_in_class=True, + recursive=recursive, override=override, when=when, ) @@ -691,9 +718,10 @@ def provide_all_on_instance( cache: bool = True, recursive: bool = False, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: return _provide_all( - provides=provides, scope=scope, - cache=cache, is_in_class=False, - recursive=recursive, override=override, + provides=provides, scope=scope, cache=cache, + is_in_class=False, + recursive=recursive, override=override, when=when, ) diff --git a/src/dishka/provider/provider.py b/src/dishka/provider/provider.py index d343f520..fb0b4716 100644 --- a/src/dishka/provider/provider.py +++ b/src/dishka/provider/provider.py @@ -10,6 +10,7 @@ DependencySource, Factory, ) +from dishka.entities.activator import Activator from dishka.entities.component import DEFAULT_COMPONENT, Component from dishka.entities.scope import BaseScope from .base_provider import BaseProvider, ProviderWrapper @@ -47,13 +48,15 @@ class Provider(BaseProvider): """ scope: BaseScope | None = None component: Component = DEFAULT_COMPONENT + when: Activator | None = None def __init__( self, scope: BaseScope | None = None, component: Component | None = None, + when: Activator | None = None, ): - super().__init__(component) + super().__init__(component, when=when) self.scope = self.scope or scope self._init_dependency_sources() @@ -128,6 +131,7 @@ def provide( cache: bool = True, recursive: bool = False, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: if scope is None: scope = self.scope @@ -138,6 +142,7 @@ def provide( cache=cache, recursive=recursive, override=override, + when=when, ) self._add_dependency_sources(str(source), composite.dependency_sources) return composite @@ -149,6 +154,7 @@ def provide_all( cache: bool = True, recursive: bool = False, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: if scope is None: scope = self.scope @@ -158,6 +164,7 @@ def provide_all( cache=cache, recursive=recursive, override=override, + when=when, ) self._add_dependency_sources("?", composite.dependency_sources) return composite @@ -170,6 +177,7 @@ def alias( cache: bool = True, component: Component | None = None, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: composite = alias( source=source, @@ -177,6 +185,7 @@ def alias( cache=cache, component=component, override=override, + when=when, ) self._add_dependency_sources(str(source), composite.dependency_sources) return composite @@ -187,11 +196,13 @@ def decorate( *, provides: Any = None, scope: BaseScope | None = None, + when: Activator | None = None, ) -> CompositeDependencySource: composite = decorate_on_instance( source=source, provides=provides, scope=scope, + when=when, ) self._add_dependency_sources(str(source), composite.dependency_sources) return composite @@ -205,11 +216,13 @@ def from_context( *, scope: BaseScope | None = None, override: bool = False, + when: Activator | None = None, ) -> CompositeDependencySource: composite = from_context( provides=provides, scope=scope or self.scope, override=override, + when=when, ) self._add_dependency_sources( name=str(provides), diff --git a/src/dishka/provider/unpack_provides.py b/src/dishka/provider/unpack_provides.py index ce29fbd4..c2573ad5 100644 --- a/src/dishka/provider/unpack_provides.py +++ b/src/dishka/provider/unpack_provides.py @@ -27,6 +27,7 @@ def unpack_factory(factory: Factory) -> Sequence[DependencySource]: ).with_component(factory.provides.component), cache=factory.cache, override=factory.override, + when=factory.when, ) for provides_other in provides_others ] @@ -43,6 +44,7 @@ def unpack_factory(factory: Factory) -> Sequence[DependencySource]: provides=hint_to_dependency_key( provides_first, ).with_component(factory.provides.component), + when=factory.when, ), ) return res @@ -58,6 +60,7 @@ def unpack_decorator(decorator: Decorator) -> Sequence[DependencySource]: provides=hint_to_dependency_key( provides, ).with_component(decorator.provides.component), + when=decorator.when, ) for provides in get_args(decorator.provides.type_hint) ] @@ -75,6 +78,7 @@ def unpack_alias(alias: Alias) -> Sequence[DependencySource]: source=alias.source, cache=alias.cache, override=alias.override, + when=alias.when, ) for provides in get_args(alias.provides.type_hint) ] diff --git a/src/dishka/registry.py b/src/dishka/registry.py index ad118897..67596c67 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -169,6 +169,7 @@ def _get_type_var_factory(self, dependency: DependencyKey) -> Factory: cache=False, override=False, source=lambda: typevar, + when=None, ) def _specialize_generic( @@ -214,4 +215,5 @@ def _specialize_generic( scope=factory.scope, cache=factory.cache, override=factory.override, + when=factory.when, ) diff --git a/src/dishka/registry_builder.py b/src/dishka/registry_builder.py index 1c6125e9..c80efc7a 100644 --- a/src/dishka/registry_builder.py +++ b/src/dishka/registry_builder.py @@ -7,8 +7,10 @@ Alias, ContextVariable, Decorator, + DependencySource, Factory, ) +from .entities.activator import ActivationContext from .entities.component import DEFAULT_COMPONENT, Component from .entities.factory_type import FactoryType from .entities.key import DependencyKey @@ -133,6 +135,9 @@ def _find_other_component(self, key: DependencyKey) -> list[Factory]: return found +ProviderSource = tuple[BaseProvider, DependencySource] + + class RegistryBuilder: def __init__( self, @@ -142,9 +147,13 @@ def __init__( container_key: DependencyKey, skip_validation: bool, validation_settings: ValidationSettings, + root_context: dict[Any, Any] | None, ) -> None: self.scopes = scopes self.providers = providers + self.all_sources: list[ProviderSource] = [] + self.active_sources: set[ProviderSource] = set() + self.inactive_sources: set[ProviderSource] = set() self.dependency_scopes: dict[DependencyKey, BaseScope] = {} self.components: set[Component] = {DEFAULT_COMPONENT} self.alias_sources: dict[DependencyKey, Any] = {} @@ -154,35 +163,106 @@ def __init__( self.skip_validation = skip_validation self.validation_settings = validation_settings self.processed_factories: dict[DependencyKey, Factory] = {} + self.root_context = root_context + + def _is_active( + self, + provider: BaseProvider, + source: DependencySource, + request_stack: list[DependencyKey], + ) -> bool: + if (provider, source) in self.active_sources: + return True + if (provider, source) in self.inactive_sources: + return False + + key = source.provides.with_component(provider.component) + context = ActivationContext( + container_context=self.root_context, + container_key=self.container_key, + key=key, + builder=self, + request_stack=[*request_stack, key], + ) + if ((not source.when or source.when(context)) and + (not provider.when or provider.when(context))): + self.active_sources.add((provider, source)) + return True + self.inactive_sources.add((provider, source)) + return False + + def _collect_sources(self) -> None: + for provider in self.providers: + for factory in provider.factories: + self.all_sources.append((provider, factory)) + for alias in provider.aliases: + self.all_sources.append((provider, alias)) + for context_var in provider.context_vars: + self.all_sources.append((provider, context_var)) + for decorator in provider.decorators: + self.all_sources.append((provider, decorator)) + + def _filter_active_sources(self) -> None: + self.all_sources = [ + (provider, source) + for provider, source in self.all_sources + if self._is_active(provider, source, []) + ] + + def has_active( + self, + key: DependencyKey, + request_stack: list[DependencyKey], + ) -> bool: + for provider, source in self.all_sources: + src_key = source.provides.with_component(provider.component) + if ( + src_key==key + and self._is_active(provider, source, request_stack) + ): + return True + return False def _collect_components(self) -> None: for provider in self.providers: - self.components.add(provider.component) + if provider.component is not None: + self.components.add(provider.component) def _collect_provided_scopes(self) -> None: - for provider in self.providers: - for factory in provider.factories: - if not isinstance(factory.scope, self.scopes): - raise UnknownScopeError(factory.scope, self.scopes) - provides = factory.provides.with_component(provider.component) - self.dependency_scopes[provides] = factory.scope - for context_var in provider.context_vars: - for component in self.components: - provides = context_var.provides.with_component(component) - # typing.cast is applied because the scope - # was checked above - self.dependency_scopes[provides] = cast( - BaseScope, context_var.scope, - ) + for provider, source in self.all_sources: + match source: + case Factory(): + if not isinstance(source.scope, self.scopes): + raise UnknownScopeError(source.scope, self.scopes) + key = source.provides.with_component(provider.component) + self.dependency_scopes[key] = source.scope + case ContextVariable(): + if not isinstance(source.scope, self.scopes): + raise UnknownScopeError(source.scope, self.scopes) + for component in self.components: + key = source.provides.with_component(component) + # typing.cast is applied because the scope + # was checked above + self.dependency_scopes[key] = source.scope + case Decorator(): + if not source.scope: + continue + if not isinstance(source.scope, self.scopes): + raise UnknownScopeError(source.scope, self.scopes) + key = source.provides.with_component(provider.component) + self.dependency_scopes[key] = source.scope + + def _collect_aliases(self) -> None: - for provider in self.providers: - component = provider.component - for alias in provider.aliases: - provides = alias.provides.with_component(component) - alias_source = alias.source.with_component(component) - self.alias_sources[provides] = alias_source - self.aliases[provides] = alias + for provider, source in self.all_sources: + match source: + case Alias(): + component = provider.component + provides = source.provides.with_component(component) + alias_source = source.source.with_component(component) + self.alias_sources[provides] = alias_source + self.aliases[provides] = source def _make_registries(self) -> tuple[Registry, ...]: registries: dict[BaseScope, Registry] = {} @@ -193,6 +273,8 @@ def _make_registries(self) -> tuple[Registry, ...]: provides=self.container_key, scope=scope, override=False, + # Container have no activation function. + when=None, ) for component in self.components: registry.add_factory(context_var.as_factory(component)) @@ -432,29 +514,45 @@ def _process_context_var( ) self.processed_factories[factory.provides] = factory + def _process_source( + self, + provider: BaseProvider, + source: DependencySource, + ) -> None: + match source: + case Factory(): + self._process_factory(provider, source) + case Alias(): + self._process_alias(provider, source) + case ContextVariable(): + self._process_context_var(provider, source) + case Decorator(): + if source.is_generic(): + self._process_generic_decorator(provider, source) + else: + self._process_normal_decorator(provider, source) + case _: + raise TypeError + def build(self) -> tuple[Registry, ...]: self._collect_components() + self._collect_sources() + self._filter_active_sources() + self._collect_provided_scopes() self._collect_aliases() - for provider in self.providers: - for factory in provider.factories: - self.dependency_scopes[ - factory.provides.with_component(provider.component) - ] = cast(BaseScope, factory.scope) - - for provider in self.providers: - for factory in provider.factories: - self._process_factory(provider, factory) - for alias in provider.aliases: - self._process_alias(provider, alias) - for context_var in provider.context_vars: - self._process_context_var(provider, context_var) - for decorator in provider.decorators: - if decorator.is_generic(): - self._process_generic_decorator(provider, decorator) - else: - self._process_normal_decorator(provider, decorator) + for provider, source in self.all_sources: + match source: + case Factory(): + key = source.provides.with_component(provider.component) + self.dependency_scopes[key] = cast(BaseScope, source.scope) + case ContextVariable(): + key = source.provides.with_component(provider.component) + self.dependency_scopes[key] = cast(BaseScope, source.scope) + + for provider, source in self.all_sources: + self._process_source(provider, source) self._post_process_generic_factories() registries = self._make_registries() diff --git a/tests/unit/container/test_when.py b/tests/unit/container/test_when.py new file mode 100644 index 00000000..cb835902 --- /dev/null +++ b/tests/unit/container/test_when.py @@ -0,0 +1,178 @@ +import pytest + +from dishka import ( + ActivationContext, + Has, + Provider, + Scope, + alias, + decorate, + from_context, + make_container, + provide, + provide_all, +) +from dishka.exceptions import NoFactoryError + + +class A: + pass + +class B: + pass + +def always(x: ActivationContext) -> bool: + return True + + +def never(x: ActivationContext) -> bool: + return False + + +@pytest.mark.parametrize( + ("provides", "value"), [ + (int, "int"), + (float, "float"), + (complex, "default"), + ], +) +def test_has_cls(provides, value): + p = Provider(scope=Scope.APP) + p.provide(lambda: 42, provides=provides) + p.provide(lambda: "default", provides=str) + p.provide(lambda: "int", provides=str, when=Has(int)) + p.provide(lambda: "float", provides=str, when=Has(float)) + c = make_container(p) + + assert c.get(str) == value + + +def test_has_cycle(): + p = Provider(scope=Scope.APP) + p.provide(lambda: 42, provides=int, when=Has(str)) + p.provide(lambda: "s", provides=str, when=Has(int)) + c = make_container(p) + assert c.get(str) == "s" + assert c.get(int) == 42 + + +def test_chain(): + p = Provider(scope=Scope.APP) + p.provide(lambda: 42, provides=int, when=never) + p.provide(lambda: "s", provides=str, when=Has(int)) + c = make_container(p) + with pytest.raises(NoFactoryError): + c.get(int) + with pytest.raises(NoFactoryError): + c.get(str) + + +def test_custom_predicate_on(): + p = Provider(scope=Scope.APP) + p.provide(lambda: 42, provides=int, when=always) + p.from_context(provides=str, when=always) + p.alias(int, provides=float, when=always) + p.provide_all(A, B, when=always) + + c = make_container(p, context={str: "x"}) + assert c.get(int) == 42 + assert c.get(str) == "x" + assert c.get(float) == 42 + assert isinstance(c.get(A), A) + assert isinstance(c.get(B), B) + + +def test_custom_predicate_off(): + p = Provider(scope=Scope.APP) + p.provide(lambda: 42, provides=int, when=never) + p.from_context(provides=str, when=never) + p.alias(int, provides=float, when=never) + p.provide_all(A, B, when=never) + c = make_container(p) + with pytest.raises(NoFactoryError): + c.get(int) + with pytest.raises(NoFactoryError): + c.get(str) + with pytest.raises(NoFactoryError): + c.get(float) + with pytest.raises(NoFactoryError): + c.get(A) + with pytest.raises(NoFactoryError): + c.get(B) + + +def test_provider(): + p1 = Provider(scope=Scope.APP) + p1.provide(lambda: 1, provides=int) + p2 = Provider(scope=Scope.APP, when=always) + p2.provide(lambda: 2, provides=int, when=never) + p3 = Provider(scope=Scope.APP, when=never) + p3.provide(lambda: 3, provides=int, when=always) + c = make_container(p1, p2, p3) + assert c.get(int) == 1 + + +def test_provider_class_when(): + class MyProvide(Provider): + def when(self, x: ActivationContext) -> bool: + return False + + p = MyProvide(scope=Scope.APP) + p.provide(lambda: 1, provides=int) + c = make_container(p) + with pytest.raises(NoFactoryError): + c.get(int) + + +def add(x: int) -> int: + return x + 1 + + +def neg(x: int) -> int: + return -x + + +def test_decorator(): + p = Provider(scope=Scope.APP) + p.provide(lambda: 1, provides=int) + p.decorate(add, when=always) + p.decorate(neg, when=never) + + c = make_container(p) + assert c.get(int) == 2 + + +def test_class_based(): + class MyProvide(Provider): + scope = Scope.APP + + @provide + def b(self) -> complex: + return 42 + + @provide(when=never) + def i(self) -> int: + return 1 + + a = alias(complex, provides=float, when=never) + s = from_context(str, when=never) + + @decorate(when=never) + def d(self, value: complex) -> complex: + return value * 100 + + x = provide_all(A, B, when=never) + + p = MyProvide() + c = make_container(p) + with pytest.raises(NoFactoryError): + c.get(int) + with pytest.raises(NoFactoryError): + c.get(float) + with pytest.raises(NoFactoryError): + c.get(str) + with pytest.raises(NoFactoryError): + c.get(A) + with pytest.raises(NoFactoryError): + c.get(B) + assert c.get(complex) == 42 diff --git a/tests/unit/sample_providers.py b/tests/unit/sample_providers.py index 25c23575..e2bfa1ac 100644 --- a/tests/unit/sample_providers.py +++ b/tests/unit/sample_providers.py @@ -75,5 +75,6 @@ async def async_gen_a(self, dep: int) -> AsyncGenerator[ClassA, None]: is_to_bind=False, cache=False, override=False, + when=None, ) value_source = CompositeDependencySource(lambda: None, [value_factory]) diff --git a/tests/unit/test_provider.py b/tests/unit/test_provider.py index a6600864..624dc7b7 100644 --- a/tests/unit/test_provider.py +++ b/tests/unit/test_provider.py @@ -68,7 +68,6 @@ def foo(self, x: object) -> str: assert len(provider.factories) == 1 assert len(provider.aliases) == 1 - @pytest.mark.parametrize( ("source", "provider_type", "is_to_bound"), [ @@ -82,8 +81,9 @@ def foo(self, x: object) -> str: (async_gen_a, FactoryType.ASYNC_GENERATOR, True), ], ) -def test_parse_factory(source, provider_type, is_to_bound): - composite = provide(source, scope=Scope.REQUEST) +@pytest.mark.parametrize("when", [None, lambda x: bool(x)]) +def test_parse_factory(source, provider_type, is_to_bound, when): + composite = provide(source, scope=Scope.REQUEST, when=when) assert len(composite.dependency_sources) == 1 factory = composite.dependency_sources[0] @@ -96,6 +96,7 @@ def test_parse_factory(source, provider_type, is_to_bound): assert factory.scope == Scope.REQUEST assert factory.source == source assert factory.type == provider_type + assert factory.when == when def test_provide_no_scope(): @@ -137,7 +138,8 @@ async def foo(self) -> int: (ClassA, FactoryType.FACTORY, False), ], ) -def test_parse_factory_cls(source, provider_type, is_to_bound): +@pytest.mark.parametrize("when", [None, lambda x: bool(x)]) +def test_parse_factory_cls(source, provider_type, is_to_bound, when): factory = make_factory( provides=None, source=source, @@ -145,6 +147,7 @@ def test_parse_factory_cls(source, provider_type, is_to_bound): scope=Scope.REQUEST, is_in_class=False, override=False, + when=when, ) assert factory.provides == hint_to_dependency_key(ClassA) assert factory.dependencies == [hint_to_dependency_key(int)] @@ -152,6 +155,7 @@ def test_parse_factory_cls(source, provider_type, is_to_bound): assert factory.scope == Scope.REQUEST assert factory.source == source assert factory.type == provider_type + assert factory.when == when def test_provider_class_scope(): @@ -430,6 +434,7 @@ def b(self, num: int) -> float: cache=True, is_in_class=False, override=False, + when=None, ) diff --git a/tests/unit/test_quickstart_example.py b/tests/unit/test_quickstart_example.py index 3d793e43..302385fa 100644 --- a/tests/unit/test_quickstart_example.py +++ b/tests/unit/test_quickstart_example.py @@ -6,4 +6,4 @@ def test_readme_example(): - runpy.run_path(QUICKSTART_EXAMPLE_PATH) + runpy.run_path(str(QUICKSTART_EXAMPLE_PATH)) diff --git a/tests/unit/test_registry.py b/tests/unit/test_registry.py index e0db5425..1208b08f 100644 --- a/tests/unit/test_registry.py +++ b/tests/unit/test_registry.py @@ -26,6 +26,7 @@ def factory() -> Factory: cache=True, is_in_class=False, override=False, + when=None, ) diff --git a/tests/unit/text_rendering/test_path.py b/tests/unit/text_rendering/test_path.py index d902eb1a..2260826f 100644 --- a/tests/unit/text_rendering/test_path.py +++ b/tests/unit/text_rendering/test_path.py @@ -32,6 +32,7 @@ def test_cycle(cycle_renderer): scope=None, is_in_class=True, override=False, + when=None, ) res = cycle_renderer.render([factory, factory, factory]) @@ -53,6 +54,7 @@ def test_cycle_2scopes(cycle_renderer): scope=Scope.APP, is_in_class=True, override=False, + when=None, ) factory = make_factory( provides=Annotated[int, FromComponent("")], @@ -61,6 +63,7 @@ def test_cycle_2scopes(cycle_renderer): scope=Scope.REQUEST, is_in_class=True, override=False, + when=None, ) res = cycle_renderer.render([app_factory, factory, factory, factory]) @@ -84,6 +87,7 @@ def test_cycle_1(cycle_renderer): scope=None, is_in_class=True, override=False, + when=None, ) res = cycle_renderer.render([factory]) @@ -103,6 +107,7 @@ def test_linear(linear_renderer): scope=Scope.APP, is_in_class=True, override=False, + when=None, ) res = linear_renderer.render([factory, factory], last=factory.provides) @@ -124,6 +129,7 @@ def test_linear_2scopes(linear_renderer): scope=Scope.APP, is_in_class=True, override=False, + when=None, ) factory = make_factory( provides=Annotated[int, FromComponent("")], @@ -132,6 +138,7 @@ def test_linear_2scopes(linear_renderer): scope=Scope.REQUEST, is_in_class=True, override=False, + when=None, ) res = linear_renderer.render(