From ae6db2309b0e4b165861294848a0a4ce02498d3e Mon Sep 17 00:00:00 2001 From: Daniil Kharkov Date: Sat, 29 Nov 2025 15:21:12 +0700 Subject: [PATCH] feat: add conditional provider activation --- examples/conditional_providers.py | 201 +++++++++ src/waku/di/__init__.py | 19 +- src/waku/di/_activation.py | 108 +++++ src/waku/di/_providers.py | 175 ++++++-- src/waku/factory.py | 10 +- src/waku/modules/_metadata.py | 6 +- src/waku/modules/_module.py | 57 ++- src/waku/modules/_registry_builder.py | 47 ++- tests/data.py | 4 +- tests/di/activation/__init__.py | 0 .../di/activation/test_activation_context.py | 91 ++++ .../activation/test_activation_integration.py | 393 ++++++++++++++++++ .../activation/test_conditional_providers.py | 136 ++++++ 13 files changed, 1200 insertions(+), 47 deletions(-) create mode 100644 examples/conditional_providers.py create mode 100644 src/waku/di/_activation.py create mode 100644 tests/di/activation/__init__.py create mode 100644 tests/di/activation/test_activation_context.py create mode 100644 tests/di/activation/test_activation_integration.py create mode 100644 tests/di/activation/test_conditional_providers.py diff --git a/examples/conditional_providers.py b/examples/conditional_providers.py new file mode 100644 index 00000000..a77865cb --- /dev/null +++ b/examples/conditional_providers.py @@ -0,0 +1,201 @@ +"""Example demonstrating conditional provider registration with the `when` feature. + +Shows: +1. Custom activators for environment-based provider selection +2. Using `Has` to conditionally activate providers based on available dependencies +""" + +from __future__ import annotations + +import asyncio +from abc import abstractmethod +from dataclasses import dataclass +from typing import Protocol + +from dishka.exceptions import NoFactoryError + +from waku import WakuFactory, module +from waku.di import ActivationContext, Has, scoped, singleton + + +class ICache(Protocol): + @abstractmethod + def get(self, key: str) -> str | None: ... + + @abstractmethod + def set(self, key: str, value: str) -> None: ... + + +class RedisCache: + """Production cache using Redis.""" + + def __init__(self) -> None: + self._data: dict[str, str] = {} + + def get(self, key: str) -> str | None: + print(f'[Redis] GET {key}') + return self._data.get(key) + + def set(self, key: str, value: str) -> None: + print(f'[Redis] SET {key}={value}') + self._data[key] = value + + +class InMemoryCache: + """Fallback in-memory cache for development/testing.""" + + def __init__(self) -> None: + self._data: dict[str, str] = {} + + def get(self, key: str) -> str | None: + print(f'[InMemory] GET {key}') + return self._data.get(key) + + def set(self, key: str, value: str) -> None: + print(f'[InMemory] SET {key}={value}') + self._data[key] = value + + +@dataclass +class AppConfig: + environment: str + + +def is_production(ctx: ActivationContext) -> bool: + """Activator that checks if running in production environment.""" + if ctx.container_context is None: + return False + config = ctx.container_context.get(AppConfig) + return config is not None and config.environment == 'production' + + +def is_not_production(ctx: ActivationContext) -> bool: + """Activator for non-production environments.""" + return not is_production(ctx) + + +# Example 1: Environment-based provider selection +@module( + providers=[ + singleton(ICache, RedisCache, when=is_production), + singleton(ICache, InMemoryCache, when=is_not_production), + ], + exports=[ICache], +) +class CacheModule: + """Module with environment-based cache selection.""" + + +# --- Example 2: Conditional provider based on Has --- + + +class IMetricsCollector(Protocol): + @abstractmethod + def record(self, metric: str, value: float) -> None: ... + + +class PrometheusCollector: + def record(self, metric: str, value: float) -> None: + print(f'[Prometheus] {metric}={value}') + + +class MetricsService: + """Service that only activates when IMetricsCollector is available.""" + + def __init__(self, collector: IMetricsCollector) -> None: + self.collector = collector + + def track_request(self, endpoint: str) -> None: + self.collector.record(f'requests.{endpoint}', 1.0) + + +@module( + providers=[singleton(IMetricsCollector, PrometheusCollector)], + exports=[IMetricsCollector], +) +class MetricsModule: + """Optional module providing metrics collection.""" + + +class UserService: + def __init__(self, cache: ICache) -> None: + self.cache = cache + + def get_user(self, user_id: str) -> str: + if cached := self.cache.get(f'user:{user_id}'): + return cached + user_data = f'User-{user_id}' + self.cache.set(f'user:{user_id}', user_data) + return user_data + + +@module( + imports=[CacheModule], + providers=[scoped(UserService)], +) +class AppModule: + pass + + +async def demo_environment_based() -> None: + """Demo 1: Environment-based provider selection.""" + print('=== Example 1: Environment-based Selection ===\n') + + # Production environment - uses RedisCache + print('Production:') + prod_config = AppConfig(environment='production') + app = WakuFactory(AppModule, context={AppConfig: prod_config}).create() + async with app, app.container() as container: + service = await container.get(UserService) + service.get_user('123') + + # Development environment - uses InMemoryCache + print('\nDevelopment:') + dev_config = AppConfig(environment='development') + app = WakuFactory(AppModule, context={AppConfig: dev_config}).create() + async with app, app.container() as container: + service = await container.get(UserService) + service.get_user('456') + + +async def demo_has_conditional() -> None: + """Demo 2: Conditional activation with Has.""" + print('\n=== Example 2: Conditional with Has ===\n') + + # With MetricsModule - MetricsService is available + @module( + imports=[MetricsModule], + providers=[scoped(MetricsService, when=Has(IMetricsCollector))], + ) + class AppWithMetrics: + pass + + print('With MetricsModule imported:') + app = WakuFactory(AppWithMetrics).create() + async with app, app.container() as container: + service = await container.get(MetricsService) + service.track_request('/api/users') + + # Without MetricsModule - MetricsService is not registered + @module( + providers=[scoped(MetricsService, when=Has(IMetricsCollector))], + ) + class AppWithoutMetrics: + pass + + print('\nWithout MetricsModule (MetricsService not available):') + app = WakuFactory(AppWithoutMetrics).create() + async with app, app.container() as container: + try: + await container.get(MetricsService) + except NoFactoryError: + print('MetricsService not available (as expected)') + + +async def main() -> None: + await demo_environment_based() + await demo_has_conditional() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/waku/di/__init__.py b/src/waku/di/__init__.py index b5ee0730..3b508120 100644 --- a/src/waku/di/__init__.py +++ b/src/waku/di/__init__.py @@ -14,16 +14,33 @@ ) from dishka.provider import BaseProvider -from waku.di._providers import contextual, many, object_, provider, scoped, singleton, transient +from waku.di._activation import ( + ActivationBuilder, + ActivationContext, + Activator, + ConditionalProvider, + Has, + IProviderFilter, + ProviderFilter, +) +from waku.di._providers import ProviderSpec, contextual, many, object_, provider, scoped, singleton, transient __all__ = [ 'DEFAULT_COMPONENT', + 'ActivationBuilder', + 'ActivationContext', + 'Activator', 'AnyOf', 'AsyncContainer', 'BaseProvider', + 'ConditionalProvider', 'FromComponent', + 'Has', + 'IProviderFilter', 'Injected', 'Provider', + 'ProviderFilter', + 'ProviderSpec', 'Scope', 'WithParents', 'alias', diff --git a/src/waku/di/_activation.py b/src/waku/di/_activation.py new file mode 100644 index 00000000..84725633 --- /dev/null +++ b/src/waku/di/_activation.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypeAlias + +if TYPE_CHECKING: + from dishka import Provider + + from waku import DynamicModule + from waku.di._providers import ProviderSpec + from waku.modules import ModuleType + +__all__ = [ + 'ActivationBuilder', + 'ActivationContext', + 'Activator', + 'ConditionalProvider', + 'Has', + 'IProviderFilter', + 'ProviderFilter', +] + + +class ActivationBuilder(Protocol): + @abstractmethod + def has_active(self, type_: Any) -> bool: + raise NotImplementedError + + +class ActivationContext(NamedTuple): + """Context passed to activators for provider activation decisions.""" + + container_context: dict[Any, Any] | None + module_type: ModuleType | DynamicModule + provided_type: Any + builder: ActivationBuilder + + +Activator: TypeAlias = Callable[[ActivationContext], bool] + + +@dataclass(frozen=True, slots=True) +class Has: + """Activator that checks if a provider for a type is registered.""" + + type_: Any + + def __call__(self, ctx: ActivationContext) -> bool: + return ctx.builder.has_active(self.type_) + + +@dataclass(frozen=True, slots=True) +class ConditionalProvider: + """Provider with activation condition.""" + + provider: Provider + when: Activator + provided_type: Any + + +class IProviderFilter(Protocol): + """Strategy for filtering providers based on activation context.""" + + def filter( + self, + providers: list[ProviderSpec], + context: dict[Any, Any] | None, + module_type: ModuleType | DynamicModule, + builder: ActivationBuilder, + ) -> list[Provider]: ... + + +OnSkipCallback: TypeAlias = Callable[[ConditionalProvider, ActivationContext], None] + + +@dataclass(slots=True) +class ProviderFilter: + """Default provider filter implementation.""" + + on_skip: OnSkipCallback | None = field(default=None) + + def filter( + self, + providers: list[ProviderSpec], + context: dict[Any, Any] | None, + module_type: ModuleType | DynamicModule, + builder: ActivationBuilder, + ) -> list[Provider]: + result: list[Provider] = [] + + for spec in providers: + if isinstance(spec, ConditionalProvider): + ctx = ActivationContext( + container_context=context, + module_type=module_type, + provided_type=spec.provided_type, + builder=builder, + ) + if spec.when(ctx): + result.append(spec.provider) + elif self.on_skip: + self.on_skip(spec, ctx) + else: + result.append(spec) + + return result diff --git a/src/waku/di/_providers.py b/src/waku/di/_providers.py index 93d17d88..90b0266b 100644 --- a/src/waku/di/_providers.py +++ b/src/waku/di/_providers.py @@ -1,12 +1,13 @@ import inspect from collections.abc import Callable, Sequence -from typing import Any, TypeVar, get_type_hints, overload +from typing import Any, TypeAlias, TypeVar, get_type_hints, overload from dishka import Provider, Scope -_T = TypeVar('_T') +from waku.di._activation import Activator, ConditionalProvider __all__ = [ + 'ProviderSpec', 'contextual', 'many', 'object_', @@ -16,6 +17,10 @@ 'transient', ] +ProviderSpec: TypeAlias = Provider | ConditionalProvider + +_T = TypeVar('_T') + def provider( source: Callable[..., Any] | type[Any], @@ -44,84 +49,157 @@ def provider( def singleton(source: type[_T] | Callable[..., _T], /) -> Provider: ... +@overload +def singleton(source: type[_T] | Callable[..., _T], /, *, when: Activator) -> ConditionalProvider: ... + + @overload def singleton(interface: Any, implementation: type[_T] | Callable[..., _T], /) -> Provider: ... +@overload +def singleton( + interface: Any, implementation: type[_T] | Callable[..., _T], /, *, when: Activator +) -> ConditionalProvider: ... + + def singleton( interface_or_source: type[Any] | Callable[..., Any], implementation: type[Any] | Callable[..., Any] | None = None, /, -) -> Provider: + *, + when: Activator | None = None, +) -> ProviderSpec: """Create a singleton provider (lifetime: app). Args: interface_or_source: Interface type or source if no separate implementation. implementation: Implementation type if interface is provided. + when: Optional predicate to conditionally activate the provider. Returns: - Provider: Singleton provider instance. + Provider or ConditionalProvider if `when` is specified. """ if implementation is not None: - return provider(implementation, scope=Scope.APP, provided_type=interface_or_source) - return provider(interface_or_source, scope=Scope.APP) + provided_type = interface_or_source + base = provider(implementation, scope=Scope.APP, provided_type=provided_type) + else: + provided_type = _get_provided_type(interface_or_source) + base = provider(interface_or_source, scope=Scope.APP) + + if when is None: + return base + return ConditionalProvider(provider=base, when=when, provided_type=provided_type) @overload def scoped(source: type[_T] | Callable[..., _T], /) -> Provider: ... +@overload +def scoped(source: type[_T] | Callable[..., _T], /, *, when: Activator) -> ConditionalProvider: ... + + @overload def scoped(interface: Any, implementation: type[_T] | Callable[..., _T], /) -> Provider: ... +@overload +def scoped( + interface: Any, implementation: type[_T] | Callable[..., _T], /, *, when: Activator +) -> ConditionalProvider: ... + + def scoped( interface_or_source: type[Any] | Callable[..., Any], implementation: type[Any] | Callable[..., Any] | None = None, /, -) -> Provider: + *, + when: Activator | None = None, +) -> ProviderSpec: """Create a scoped provider (lifetime: request). Args: interface_or_source: Interface type or source if no separate implementation. implementation: Implementation type if interface is provided. + when: Optional predicate to conditionally activate the provider. Returns: - Provider: Scoped provider instance. + Provider or ConditionalProvider if `when` is specified. """ if implementation is not None: - return provider(implementation, scope=Scope.REQUEST, provided_type=interface_or_source) - return provider(interface_or_source, scope=Scope.REQUEST) + provided_type = interface_or_source + base = provider(implementation, scope=Scope.REQUEST, provided_type=provided_type) + else: + provided_type = _get_provided_type(interface_or_source) + base = provider(interface_or_source, scope=Scope.REQUEST) + + if when is None: + return base + return ConditionalProvider(provider=base, when=when, provided_type=provided_type) @overload def transient(source: type[_T] | Callable[..., _T], /) -> Provider: ... +@overload +def transient(source: type[_T] | Callable[..., _T], /, *, when: Activator) -> ConditionalProvider: ... + + @overload def transient(interface: Any, implementation: type[_T] | Callable[..., _T], /) -> Provider: ... +@overload +def transient( + interface: Any, implementation: type[_T] | Callable[..., _T], /, *, when: Activator +) -> ConditionalProvider: ... + + def transient( interface_or_source: type[Any] | Callable[..., Any], implementation: type[Any] | Callable[..., Any] | None = None, /, -) -> Provider: + *, + when: Activator | None = None, +) -> ProviderSpec: """Create a transient provider (new instance per injection). Args: interface_or_source: Interface type or source if no separate implementation. implementation: Implementation type if interface is provided. + when: Optional predicate to conditionally activate the provider. Returns: - Provider: Transient provider instance. + Provider or ConditionalProvider if `when` is specified. """ if implementation is not None: - return provider(implementation, scope=Scope.REQUEST, provided_type=interface_or_source, cache=False) - return provider(interface_or_source, scope=Scope.REQUEST, cache=False) + provided_type = interface_or_source + base = provider(implementation, scope=Scope.REQUEST, provided_type=provided_type, cache=False) + else: + provided_type = _get_provided_type(interface_or_source) + base = provider(interface_or_source, scope=Scope.REQUEST, cache=False) + + if when is None: + return base + return ConditionalProvider(provider=base, when=when, provided_type=provided_type) + + +@overload +def object_(obj: Any, *, provided_type: Any | None = None) -> Provider: ... + +@overload +def object_(obj: Any, *, provided_type: Any | None = None, when: Activator) -> ConditionalProvider: ... -def object_(obj: Any, *, provided_type: Any | None = None) -> Provider: + +def object_( + obj: Any, + *, + provided_type: Any | None = None, + when: Activator | None = None, +) -> ProviderSpec: """Provide the exact object passed at creation time as a singleton dependency. The provider always returns the same object instance, without instantiation or copying. @@ -129,26 +207,49 @@ def object_(obj: Any, *, provided_type: Any | None = None) -> Provider: Args: obj: The instance to provide as-is. provided_type: Explicit type to provide (default: inferred). + when: Optional predicate to conditionally activate the provider. Returns: - Provider: Provider that always returns the given object. + Provider or ConditionalProvider if `when` is specified. """ - return provider(lambda: obj, scope=Scope.APP, provided_type=provided_type, cache=True) + actual_type = provided_type if provided_type is not None else type(obj) + base = provider(lambda: obj, scope=Scope.APP, provided_type=actual_type, cache=True) + + if when is None: + return base + return ConditionalProvider(provider=base, when=when, provided_type=actual_type) + + +@overload +def contextual(provided_type: Any, *, scope: Scope = Scope.REQUEST) -> Provider: ... + +@overload +def contextual(provided_type: Any, *, scope: Scope = Scope.REQUEST, when: Activator) -> ConditionalProvider: ... -def contextual(provided_type: Any, *, scope: Scope = Scope.REQUEST) -> Provider: + +def contextual( + provided_type: Any, + *, + scope: Scope = Scope.REQUEST, + when: Activator | None = None, +) -> ProviderSpec: """Provide a dependency from the current context (e.g., app/request). Args: provided_type: The type to resolve from context. scope: Scope of the context variable (default: Scope.REQUEST). + when: Optional predicate to conditionally activate the provider. Returns: - Provider: Contextual provider instance. + Provider or ConditionalProvider if `when` is specified. """ provider_ = Provider() provider_.from_context(provided_type, scope=scope) - return provider_ + + if when is None: + return provider_ + return ConditionalProvider(provider=provider_, when=when, provided_type=provided_type) def _get_provided_type(impl: Any) -> Any: @@ -176,12 +277,32 @@ def _get_provided_type(impl: Any) -> Any: raise TypeError(msg) +@overload +def many( + interface: Any, + *implementations: Any, + scope: Scope = Scope.REQUEST, + cache: bool = True, +) -> Provider: ... + + +@overload def many( interface: Any, *implementations: Any, scope: Scope = Scope.REQUEST, cache: bool = True, -) -> Provider: + when: Activator, +) -> ConditionalProvider: ... + + +def many( + interface: Any, + *implementations: Any, + scope: Scope = Scope.REQUEST, + cache: bool = True, + when: Activator | None = None, +) -> ProviderSpec: """Register multiple implementations as a collection. Args: @@ -189,18 +310,14 @@ def many( *implementations: Implementation types or factory functions to include in collection. scope: Scope of the collection (default: Scope.REQUEST). cache: Whether to cache the resolve results within scope. + when: Optional predicate to conditionally activate the provider. Returns: - Provider: Collection provider instance. + Provider or ConditionalProvider if `when` is specified. Raises: ValueError: If no implementations are provided. TypeError: If a factory function lacks a return type annotation. - - Examples: - many(IPipelineBehavior[Any, Any], ValidationBehavior, LoggingBehavior) - many(IEventHandler[UserCreated], EmailHandler, AuditHandler, scope=Scope.APP) - many(IRuleStrategy, rule_strategy_factory) # Factory functions supported """ if not implementations: msg = 'At least one implementation must be provided' @@ -224,4 +341,6 @@ def many( def _(many_: list[interface], one: provided_type) -> list[interface]: # type: ignore[valid-type] return [*many_, one] - return provider_ + if when is None: + return provider_ + return ConditionalProvider(provider=provider_, when=when, provided_type=Sequence[interface]) diff --git a/src/waku/factory.py b/src/waku/factory.py index 73a54ee3..578cda35 100644 --- a/src/waku/factory.py +++ b/src/waku/factory.py @@ -16,7 +16,7 @@ from collections.abc import Sequence from waku import Module - from waku.di import AsyncContainer, BaseProvider, Scope + from waku.di import AsyncContainer, BaseProvider, IProviderFilter, Scope from waku.extensions import ApplicationExtension from waku.lifespan import LifespanFunc from waku.modules import ModuleType @@ -45,6 +45,7 @@ def __init__( lifespan: Sequence[LifespanFunc] = (), extensions: Sequence[ApplicationExtension] = DEFAULT_EXTENSIONS, container_config: ContainerConfig | None = None, + provider_filter: IProviderFilter | None = None, ) -> None: self._root_module_type = root_module_type @@ -52,9 +53,14 @@ def __init__( self._lifespan = lifespan self._extensions = extensions self._container_config = container_config or ContainerConfig() + self._provider_filter = provider_filter def create(self) -> WakuApplication: - registry = ModuleRegistryBuilder(self._root_module_type).build() + registry = ModuleRegistryBuilder( + self._root_module_type, + context=self._context, + provider_filter=self._provider_filter, + ).build() container = self._build_container(registry.providers) return WakuApplication( container=container, diff --git a/src/waku/modules/_metadata.py b/src/waku/modules/_metadata.py index 64ee1442..11da3ba2 100644 --- a/src/waku/modules/_metadata.py +++ b/src/waku/modules/_metadata.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence - from waku.di import BaseProvider + from waku.di import ProviderSpec from waku.extensions import ModuleExtension __all__ = [ @@ -31,7 +31,7 @@ @dataclass(kw_only=True, slots=True) class ModuleMetadata: - providers: list[BaseProvider] = field(default_factory=list) + providers: list[ProviderSpec] = field(default_factory=list) """List of providers for dependency injection.""" imports: list[ModuleType | DynamicModule] = field(default_factory=list) """List of modules imported by this module.""" @@ -66,7 +66,7 @@ def __hash__(self) -> int: def module( *, - providers: Sequence[BaseProvider] = (), + providers: Sequence[ProviderSpec] = (), imports: Sequence[ModuleType | DynamicModule] = (), exports: Sequence[type[object] | ModuleType | DynamicModule] = (), extensions: Sequence[ModuleExtension] = (), diff --git a/src/waku/modules/_module.py b/src/waku/modules/_module.py index e6ccc393..1962cf04 100644 --- a/src/waku/modules/_module.py +++ b/src/waku/modules/_module.py @@ -1,14 +1,17 @@ from __future__ import annotations -from functools import cached_property from typing import TYPE_CHECKING, Final, cast -from waku.di import DEFAULT_COMPONENT, BaseProvider +from waku.di import DEFAULT_COMPONENT, ActivationBuilder, BaseProvider, IProviderFilter if TYPE_CHECKING: from collections.abc import Iterable, Sequence + from typing import Any from uuid import UUID + from dishka import Provider + + from waku.di import ProviderSpec from waku.extensions import ModuleExtension from waku.modules._metadata import DynamicModule, ModuleMetadata, ModuleType @@ -18,7 +21,6 @@ class Module: __slots__ = ( - '__dict__', '_provider', 'exports', 'extensions', @@ -33,20 +35,59 @@ def __init__(self, module_type: ModuleType, metadata: ModuleMetadata) -> None: self.id: Final[UUID] = metadata.id self.target: Final[ModuleType] = module_type - self.providers: Final[Sequence[BaseProvider]] = metadata.providers + self.providers: Final[Sequence[ProviderSpec]] = metadata.providers self.imports: Final[Sequence[ModuleType | DynamicModule]] = metadata.imports self.exports: Final[Sequence[type[object] | ModuleType | DynamicModule]] = metadata.exports self.extensions: Final[Sequence[ModuleExtension]] = metadata.extensions self.is_global: Final[bool] = metadata.is_global + self._provider: BaseProvider | None = None + @property def name(self) -> str: return self.target.__name__ - @cached_property + @property def provider(self) -> BaseProvider: + """Get the aggregated provider for this module. + + This property returns the provider created by create_provider(). + Must be called after create_provider() has been invoked. + + Raises: + RuntimeError: If create_provider() has not been called yet. + """ + if self._provider is None: + msg = f'Module {self.name} provider not yet created. Call create_provider() first.' + raise RuntimeError(msg) + return self._provider + + def create_provider( + self, + context: dict[Any, Any] | None, + builder: ActivationBuilder, + provider_filter: IProviderFilter, + ) -> BaseProvider: + """Create aggregated provider with activation filtering applied. + + Args: + context: Context dict for activation decisions. + builder: Activation builder for checking if types are registered. + provider_filter: Filter strategy for conditional provider activation. + + Returns: + BaseProvider with only active providers aggregated. + """ + active_providers = provider_filter.filter( + list(self.providers), + context=context, + module_type=self.target, + builder=builder, + ) + cls = cast(type[_ModuleProvider], type(f'{self.name}Provider', (_ModuleProvider,), {})) - return cls(self.providers) + self._provider = cls(active_providers) + return self._provider def __str__(self) -> str: return self.__repr__() @@ -62,7 +103,9 @@ def __eq__(self, other: object) -> bool: class _ModuleProvider(BaseProvider): - def __init__(self, providers: Iterable[BaseProvider]) -> None: + """Aggregates factories from filtered providers.""" + + def __init__(self, providers: Iterable[Provider]) -> None: super().__init__(DEFAULT_COMPONENT) for provider in providers: self.factories.extend(provider.factories) diff --git a/src/waku/modules/_registry_builder.py b/src/waku/modules/_registry_builder.py index 9a02f62b..05161506 100644 --- a/src/waku/modules/_registry_builder.py +++ b/src/waku/modules/_registry_builder.py @@ -1,9 +1,10 @@ from __future__ import annotations from collections import OrderedDict, defaultdict -from typing import TYPE_CHECKING, Final, TypeAlias +from typing import TYPE_CHECKING, Any, Final, TypeAlias from uuid import UUID +from waku.di import ConditionalProvider, IProviderFilter, ProviderFilter from waku.modules import Module, ModuleCompiler, ModuleMetadata, ModuleRegistry, ModuleType if TYPE_CHECKING: @@ -20,17 +21,40 @@ AdjacencyMatrix: TypeAlias = dict[UUID, OrderedDict[UUID, str]] +class _ActivationBuilder: + """Build-time activation builder for checking registered types.""" + + def __init__(self) -> None: + self._registered_types: set[Any] = set() + + def register(self, type_: Any) -> None: + self._registered_types.add(type_) + + def has_active(self, type_: Any) -> bool: + return type_ in self._registered_types + + class ModuleRegistryBuilder: - def __init__(self, root_module_type: ModuleType, compiler: ModuleCompiler | None = None) -> None: + def __init__( + self, + root_module_type: ModuleType, + compiler: ModuleCompiler | None = None, + context: dict[Any, Any] | None = None, + provider_filter: IProviderFilter | None = None, + ) -> None: self._compiler: Final = compiler or ModuleCompiler() self._root_module_type: Final = root_module_type + self._context: Final = context + self._provider_filter: Final[IProviderFilter] = provider_filter or ProviderFilter() self._modules: dict[UUID, Module] = {} self._providers: list[BaseProvider] = [] self._metadata_cache: dict[ModuleType | DynamicModule, tuple[ModuleType, ModuleMetadata]] = {} + self._builder: Final = _ActivationBuilder() def build(self) -> ModuleRegistry: modules, adjacency = self._collect_modules() + self._build_type_registry(modules) root_module = self._register_modules(modules) return self._build_registry(root_module, adjacency) @@ -64,6 +88,16 @@ def _collect_modules_recursive( post_order.append((type_, metadata)) visited.add(metadata.id) + def _build_type_registry(self, modules: list[tuple[ModuleType, ModuleMetadata]]) -> None: + """Build registry of all provided types before filtering.""" + for _, metadata in modules: + for spec in metadata.providers: + if isinstance(spec, ConditionalProvider): + self._builder.register(spec.provided_type) + else: + for factory in spec.factories: + self._builder.register(factory.provides.type_hint) + def _register_modules(self, post_order: list[tuple[ModuleType, ModuleMetadata]]) -> Module: for type_, metadata in post_order: if metadata.id in self._modules: @@ -75,7 +109,13 @@ def _register_modules(self, post_order: list[tuple[ModuleType, ModuleMetadata]]) module = Module(type_, metadata) self._modules[module.id] = module - self._providers.append(module.provider) + self._providers.append( + module.create_provider( + context=self._context, + builder=self._builder, + provider_filter=self._provider_filter, + ) + ) _, root_metadata = self._get_metadata(self._root_module_type) return self._modules[root_metadata.id] @@ -87,7 +127,6 @@ def _get_metadata(self, module_type: ModuleType | DynamicModule) -> tuple[Module return self._metadata_cache[module_type] def _build_registry(self, root_module: Module, adjacency: AdjacencyMatrix) -> ModuleRegistry: - # Store topological order (post_order DFS) for event triggering return ModuleRegistry( compiler=self._compiler, modules=self._modules, diff --git a/tests/data.py b/tests/data.py index c5faf8a2..5e2cc074 100644 --- a/tests/data.py +++ b/tests/data.py @@ -4,7 +4,7 @@ from typing import NewType from waku import Module -from waku.di import BaseProvider +from waku.di import ProviderSpec from waku.extensions import OnModuleConfigure, OnModuleDestroy, OnModuleInit from waku.modules import ModuleMetadata @@ -99,7 +99,7 @@ async def on_module_destroy(self, module: Module) -> None: class AddDepOnConfigure(OnModuleConfigure): """Extension that adds a dependency during module configuration.""" - def __init__(self, provider: BaseProvider) -> None: + def __init__(self, provider: ProviderSpec) -> None: self.provider = provider def on_module_configure(self, metadata: ModuleMetadata) -> None: diff --git a/tests/di/activation/__init__.py b/tests/di/activation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/di/activation/test_activation_context.py b/tests/di/activation/test_activation_context.py new file mode 100644 index 00000000..ab5a558c --- /dev/null +++ b/tests/di/activation/test_activation_context.py @@ -0,0 +1,91 @@ +import pytest + +from waku.di import ActivationContext, Has + + +class _MockBuilder: + def __init__(self, registered: set[type] | None = None) -> None: + self._registered = registered or set() + + def has_active(self, type_: object) -> bool: + return type_ in self._registered + + +class TestActivationContextFields: + @staticmethod + def test_container_context_field() -> None: + ctx = ActivationContext( + container_context={'key': 'value'}, + module_type=object, + provided_type=str, + builder=_MockBuilder(), + ) + + assert ctx.container_context == {'key': 'value'} + + @staticmethod + def test_module_type_field() -> None: + class MyModule: + pass + + ctx = ActivationContext( + container_context={}, + module_type=MyModule, + provided_type=str, + builder=_MockBuilder(), + ) + + assert ctx.module_type is MyModule + + @staticmethod + def test_provided_type_field() -> None: + ctx = ActivationContext( + container_context={}, + module_type=object, + provided_type=int, + builder=_MockBuilder(), + ) + + assert ctx.provided_type is int + + @staticmethod + def test_builder_has_active() -> None: + builder = _MockBuilder(registered={str, int}) + + ctx = ActivationContext( + container_context={}, + module_type=object, + provided_type=int, + builder=builder, + ) + + assert ctx.builder.has_active(str) is True + assert ctx.builder.has_active(float) is False + + +class TestHasActivator: + @staticmethod + @pytest.mark.parametrize( + ('registered', 'check_type', 'expected'), + [ + pytest.param({str, int}, str, True, id='type_registered'), + pytest.param({str}, float, False, id='type_not_registered'), + pytest.param(set(), str, False, id='empty_registry'), + ], + ) + def test_checks_if_type_is_registered( + registered: set[type], + check_type: type, + expected: bool, + ) -> None: + builder = _MockBuilder(registered=registered) + ctx = ActivationContext( + container_context={}, + module_type=object, + provided_type=str, + builder=builder, + ) + + has = Has(check_type) + + assert has(ctx) is expected diff --git a/tests/di/activation/test_activation_integration.py b/tests/di/activation/test_activation_integration.py new file mode 100644 index 00000000..ef8ac872 --- /dev/null +++ b/tests/di/activation/test_activation_integration.py @@ -0,0 +1,393 @@ +from dataclasses import dataclass +from typing import Any, Protocol + +import pytest +from dishka import Provider +from dishka.exceptions import GraphMissingFactoryError, NoFactoryError + +from waku import DynamicModule, WakuFactory +from waku.di import ( + ActivationBuilder, + ActivationContext, + ConditionalProvider, + IProviderFilter, + ProviderFilter, + ProviderSpec, + scoped, + singleton, +) +from waku.modules import ModuleType + +from tests.data import A, B, Service +from tests.module_utils import create_basic_module + + +class _MockBuilder: + def __init__(self, registered: set[type] | None = None) -> None: + self._registered = registered or set() + + def has_active(self, type_: object) -> bool: + return type_ in self._registered + + +def when_redis(ctx: ActivationContext) -> bool: + return bool(ctx.container_context.get('use_redis')) if ctx.container_context else False + + +def when_production(ctx: ActivationContext) -> bool: + return ctx.container_context.get('environment') == 'production' if ctx.container_context else False + + +def when_debug(ctx: ActivationContext) -> bool: + return bool(ctx.container_context.get('debug')) if ctx.container_context else False + + +def always(_: ActivationContext) -> bool: + return True + + +def never(_: ActivationContext) -> bool: + return False + + +async def test_activated_provider_available_in_container() -> None: + AppModule = create_basic_module( + providers=[scoped(Service, when=when_redis)], + name='AppModule', + ) + + app = WakuFactory(AppModule, context={'use_redis': True}).create() + + async with app, app.container() as container: + result = await container.get(Service) + assert isinstance(result, Service) + + +@pytest.mark.parametrize( + 'context', + [ + pytest.param({'use_redis': False}, id='explicit_false'), + pytest.param({}, id='missing_key'), + ], +) +async def test_deactivated_provider_raises_no_factory_error(context: dict[str, object]) -> None: + AppModule = create_basic_module( + providers=[scoped(Service, when=when_redis)], + name='AppModule', + ) + + app = WakuFactory(AppModule, context=context).create() + + async with app, app.container() as container: + with pytest.raises(NoFactoryError): + await container.get(Service) + + +async def test_unconditional_provider_always_available() -> None: + AppModule = create_basic_module( + providers=[scoped(Service)], + name='AppModule', + ) + + app = WakuFactory(AppModule, context={'use_redis': False}).create() + + async with app, app.container() as container: + result = await container.get(Service) + assert isinstance(result, Service) + + +async def test_multiple_conditionals_for_same_interface() -> None: + @dataclass + class RedisCache: + pass + + @dataclass + class InMemoryCache: + pass + + class ICache(Protocol): + pass + + def when_not_redis(ctx: ActivationContext) -> bool: + return not bool(ctx.container_context.get('use_redis')) if ctx.container_context else True + + AppModule = create_basic_module( + providers=[ + scoped(ICache, RedisCache, when=when_redis), + scoped(ICache, InMemoryCache, when=when_not_redis), + ], + name='AppModule', + ) + + redis_app = WakuFactory(AppModule, context={'use_redis': True}).create() + async with redis_app, redis_app.container() as container: + result = await container.get(ICache) + assert isinstance(result, RedisCache) + + inmem_app = WakuFactory(AppModule, context={'use_redis': False}).create() + async with inmem_app, inmem_app.container() as container: + result = await container.get(ICache) + assert isinstance(result, InMemoryCache) + + +async def test_some_providers_activated_some_not() -> None: + @dataclass + class DebugLogger: + pass + + @dataclass + class ProductionService: + pass + + AppModule = create_basic_module( + providers=[ + scoped(DebugLogger, when=when_debug), + scoped(ProductionService, when=when_production), + ], + name='AppModule', + ) + + app = WakuFactory( + AppModule, + context={'debug': True, 'environment': 'development'}, + ).create() + + async with app, app.container() as container: + debug = await container.get(DebugLogger) + assert isinstance(debug, DebugLogger) + + with pytest.raises(NoFactoryError): + await container.get(ProductionService) + + +async def test_conditional_dependency_available_when_active() -> None: + AppModule = create_basic_module( + providers=[ + scoped(A, when=when_redis), + scoped(B), + ], + name='AppModule', + ) + + app = WakuFactory(AppModule, context={'use_redis': True}).create() + + async with app, app.container() as container: + b = await container.get(B) + assert isinstance(b, B) + assert isinstance(b.a, A) + + +def test_conditional_dependency_fails_graph_validation_when_inactive() -> None: + AppModule = create_basic_module( + providers=[ + scoped(A, when=when_redis), + scoped(B), + ], + name='AppModule', + ) + + with pytest.raises(GraphMissingFactoryError): + WakuFactory(AppModule, context={'use_redis': False}).create() + + +async def test_activation_none_creates_empty_context() -> None: + AppModule = create_basic_module( + providers=[scoped(Service, when=when_redis)], + name='AppModule', + ) + + app = WakuFactory(AppModule, context=None).create() + + async with app, app.container() as container: + with pytest.raises(NoFactoryError): + await container.get(Service) + + +async def test_custom_filter_receives_providers_and_context() -> None: + received: list[tuple[list[ProviderSpec], dict[Any, Any] | None, ModuleType | DynamicModule, ActivationBuilder]] = [] + + class RecordingFilter(IProviderFilter): + def filter( # noqa: PLR6301 + self, + providers: list[ProviderSpec], + context: dict[Any, Any] | None, + module_type: ModuleType | DynamicModule, + builder: ActivationBuilder, + ) -> list[Provider]: + received.append((list(providers), context, module_type, builder)) + return [p if isinstance(p, Provider) else p.provider for p in providers] + + AppModule = create_basic_module( + providers=[scoped(Service)], + name='AppModule', + ) + + app = WakuFactory( + AppModule, + context={'env': 'test'}, + provider_filter=RecordingFilter(), + ).create() + + async with app, app.container() as container: + await container.get(Service) + + assert received + _providers, ctx, _module_type, _builder = received[0] + assert ctx is not None + assert ctx.get('env') == 'test' + + +async def test_custom_filter_can_always_include() -> None: + class AlwaysIncludeFilter(IProviderFilter): + def filter( # noqa: PLR6301 + self, + providers: list[ProviderSpec], + context: dict[Any, Any] | None, # noqa: ARG002 + module_type: ModuleType | DynamicModule, # noqa: ARG002 + builder: ActivationBuilder, # noqa: ARG002 + ) -> list[Provider]: + return [p if isinstance(p, Provider) else p.provider for p in providers] + + AppModule = create_basic_module( + providers=[scoped(Service, when=never)], + name='AppModule', + ) + + app = WakuFactory( + AppModule, + provider_filter=AlwaysIncludeFilter(), + ).create() + + async with app, app.container() as container: + result = await container.get(Service) + assert isinstance(result, Service) + + +async def test_custom_filter_can_always_exclude() -> None: + class AlwaysExcludeFilter(IProviderFilter): + def filter( # noqa: PLR6301 + self, + providers: list[ProviderSpec], # noqa: ARG002 + context: dict[Any, Any] | None, # noqa: ARG002 + module_type: ModuleType | DynamicModule, # noqa: ARG002 + builder: ActivationBuilder, # noqa: ARG002 + ) -> list[Provider]: + return [] + + AppModule = create_basic_module( + providers=[scoped(Service)], + name='AppModule', + ) + + app = WakuFactory( + AppModule, + provider_filter=AlwaysExcludeFilter(), + ).create() + + async with app, app.container() as container: + with pytest.raises(NoFactoryError): + await container.get(Service) + + +def test_on_skip_called_during_factory_creation() -> None: + skipped: list[ConditionalProvider] = [] + + def record_skip(cond: ConditionalProvider, _: ActivationContext) -> None: + skipped.append(cond) + + filter_ = ProviderFilter(on_skip=record_skip) + + AppModule = create_basic_module( + providers=[ + scoped(Service, when=never), + scoped(A, when=never), + ], + name='AppModule', + ) + + WakuFactory(AppModule, provider_filter=filter_).create() + + assert len(skipped) == 2 + + +async def test_production_only_provider() -> None: + @dataclass + class ProductionCache: + pass + + AppModule = create_basic_module( + providers=[singleton(ProductionCache, when=when_production)], + name='AppModule', + ) + + prod_app = WakuFactory( + AppModule, + context={'environment': 'production'}, + ).create() + + async with prod_app, prod_app.container() as container: + result = await container.get(ProductionCache) + assert isinstance(result, ProductionCache) + + dev_app = WakuFactory( + AppModule, + context={'environment': 'development'}, + ).create() + + async with dev_app, dev_app.container() as container: + with pytest.raises(NoFactoryError): + await container.get(ProductionCache) + + +async def test_multiple_environment_conditions() -> None: + @dataclass + class StagingOrProdService: + pass + + def when_staging_or_prod(ctx: ActivationContext) -> bool: + if not ctx.container_context: + return False + env = ctx.container_context.get('environment') + return env in {'staging', 'production'} + + AppModule = create_basic_module( + providers=[scoped(StagingOrProdService, when=when_staging_or_prod)], + name='AppModule', + ) + + for env in ('staging', 'production'): + app = WakuFactory(AppModule, context={'environment': env}).create() + async with app, app.container() as container: + result = await container.get(StagingOrProdService) + assert isinstance(result, StagingOrProdService) + + dev_app = WakuFactory(AppModule, context={'environment': 'development'}).create() + async with dev_app, dev_app.container() as container: + with pytest.raises(NoFactoryError): + await container.get(StagingOrProdService) + + +async def test_always_true_predicate() -> None: + AppModule = create_basic_module( + providers=[scoped(Service, when=always)], + name='AppModule', + ) + + app = WakuFactory(AppModule, context={}).create() + + async with app, app.container() as container: + result = await container.get(Service) + assert isinstance(result, Service) + + +async def test_always_false_predicate() -> None: + AppModule = create_basic_module( + providers=[scoped(Service, when=never)], + name='AppModule', + ) + + app = WakuFactory(AppModule, context={'everything': True}).create() + + async with app, app.container() as container: + with pytest.raises(NoFactoryError): + await container.get(Service) diff --git a/tests/di/activation/test_conditional_providers.py b/tests/di/activation/test_conditional_providers.py new file mode 100644 index 00000000..c15978ad --- /dev/null +++ b/tests/di/activation/test_conditional_providers.py @@ -0,0 +1,136 @@ +from collections.abc import Callable, Sequence +from typing import Any + +import pytest +from dishka import Provider + +from waku.di import ( + ActivationContext, + ConditionalProvider, + contextual, + many, + object_, + scoped, + singleton, + transient, +) + +from tests.data import A, B, Service + + +def always(_: ActivationContext) -> bool: + return True + + +class TestProviderFunctionsWithoutWhen: + @staticmethod + @pytest.mark.parametrize( + 'provider_func', + [singleton, scoped, transient, contextual], + ids=['singleton', 'scoped', 'transient', 'contextual'], + ) + def test_simple_provider_returns_provider_instance( + provider_func: Callable[..., Any], + ) -> None: + result = provider_func(Service) + + assert isinstance(result, Provider) + + @staticmethod + @pytest.mark.parametrize( + 'provider_func', + [singleton, scoped, transient], + ids=['singleton', 'scoped', 'transient'], + ) + def test_interface_implementation_returns_provider( + provider_func: Callable[..., Any], + ) -> None: + result = provider_func(A, B) + + assert isinstance(result, Provider) + + @staticmethod + def test_object_returns_provider() -> None: + instance = Service() + + result = object_(instance, provided_type=Service) + + assert isinstance(result, Provider) + + @staticmethod + def test_many_returns_provider() -> None: + result = many(Service, Service) + + assert isinstance(result, Provider) + + +class TestProviderFunctionsWithWhen: + @staticmethod + @pytest.mark.parametrize( + 'provider_func', + [singleton, scoped, transient, contextual], + ids=['singleton', 'scoped', 'transient', 'contextual'], + ) + def test_simple_provider_returns_conditional_provider( + provider_func: Callable[..., Any], + ) -> None: + result = provider_func(Service, when=always) + + assert isinstance(result, ConditionalProvider) + assert isinstance(result.provider, Provider) + assert result.provided_type is Service + + @staticmethod + @pytest.mark.parametrize( + 'provider_func', + [singleton, scoped, transient], + ids=['singleton', 'scoped', 'transient'], + ) + def test_interface_implementation_returns_conditional_provider( + provider_func: Callable[..., Any], + ) -> None: + result = provider_func(A, B, when=always) + + assert isinstance(result, ConditionalProvider) + assert result.provided_type is A + + @staticmethod + def test_object_returns_conditional_provider() -> None: + instance = Service() + + result = object_(instance, provided_type=Service, when=always) + + assert isinstance(result, ConditionalProvider) + assert result.provided_type is Service + + @staticmethod + def test_many_returns_conditional_provider() -> None: + result = many(Service, Service, when=always) + + assert isinstance(result, ConditionalProvider) + assert result.provided_type == Sequence[Service] + + +class TestPredicateAttachment: + @staticmethod + @pytest.mark.parametrize( + ('provider_func', 'args'), + [ + (singleton, (Service,)), + (scoped, (Service,)), + (transient, (Service,)), + (contextual, (Service,)), + ], + ids=['singleton', 'scoped', 'transient', 'contextual'], + ) + def test_predicate_is_correctly_attached( + provider_func: Callable[..., Any], + args: tuple[Any, ...], + ) -> None: + def custom_predicate(_: ActivationContext) -> bool: + return True + + result = provider_func(*args, when=custom_predicate) + + assert isinstance(result, ConditionalProvider) + assert result.when is custom_predicate