diff --git a/README.md b/README.md index 0880b4df..223c6c93 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,160 @@ -## DIshka - DI by Tishka17 +## DIshka (from russian "small DI") -Minimal DI framework with scopes +Small DI framework with scopes and agreeable API. + +### Purpose + +This library is targeting to provide only an IoC-container. If you are tired manually passing objects to create others objects which are only used to create more object - we have a solution. Otherwise, you do not probably need a IoC-container but check what we have. + +Unlike other instruments we are not trying to solve tasks not related to dependency injection. We want to keep DI in place, not soiling you code with global variables and additional specifiers in all places. + +Main ideas: +* **Scopes**. Any object can have lifespan of the whole app, single request or even more fractionally. Many frameworks do not have scopes or have only 2 of them. Here you can have as many scopes as you need. +* **Finalization**. Some dependencies like database connections must be not only created, but carefully released. Many framework lack this essential feature +* **Modular providers**. Instead of creating lots of separate functions or contrariwise a big single class, you can split your factories into several classes, which makes them simpler reusable. +* **Clean dependencies**. You do not need to add custom markers to the code of dependencies so to allow library to see them. All customization is done within providers code and only borders of scopes have to deal with library API. +* **Simple API**. You need minimum of objects to start using library. You can easily integrate it with your task framework, examples provided. +* **Speed**. It is fast enough so you not to worry about. It is even faster than many of the analogs. + +See more in [technical requirements](docs/technical_requirements.md) ### Quickstart -1. Create Scopes enum -2. Create Provider subclass. -3. Mark methods which actually create depedencies with `@provide` decorator -4. Do not forget typehints -5. Create Container instance passing providers -6. Call `get` to get dependency and use context manager to get deeper through scopes -7. Add decorators and middleware for your framework +1. Create Provider subclass. +```python +from dishka import Provider +class MyProvider(Provider): + ... +``` +2. Mark methods which actually create dependencies with `@provide` decorator with carefully arranged scopes. Do not forget to place correct typehints for parameters and result. +Here we describe how to create instances of A and B classes, where B class requires itself an instance of A. +```python +from dishka import provide, Provider, Scope +class MyProvider(Provider): + @provide(scope=Scope.APP) + def get_a(self) -> A: + return A() + + @provide(scope=Scope.REQUEST) + def get_b(self, a: A) -> B: + return B(a) +``` +4. Create Container instance passing providers, and step into `APP` scope. Or deeper if you need. +```python +with make_container(MyProvider()) as container: # enter Scope.APP + with container() as request_container: # enter Scope.REQUEST + ... +``` + +5. Call `get` to get dependency and use context manager to get deeper through scopes +```python +with make_container(MyProvider()) as container: + a = container.get(A) # `A` has Scope.APP, so it is accessible here + with container() as request_container: + b = request_container.get(B) # `B` has Scope.REQUEST + a = request_container.get(A) # `A` is accessible here too +``` + +6. Add decorators and middleware for your framework (_would be described soon_) + +See [examples](examples) + +### Concepts + +**Dependency** is what you need for some part of your code to work. They are just object which you do not create in place and probably want to replace some day. At least for tests. +Some of them can live while you application is running, others are destroyed and created on each request. Dependencies can depend on other objects, which are their dependencies. + +**Scope** is a lifespan of a dependency. Standard scopes are: + + `APP` -> `REQUEST` -> `ACTION` -> `STEP`. + +You decide when to enter and exit them, but it is done one by one. You set a scope for your dependency when you configure how to create it. If the same dependency is requested multiple time within one scope without leaving it, then the same instance is returned. + +If you are developing web application, you would enter `APP` scope on startup, and you would `REQUEST` scope in each HTTP-request. + +You can provide your own Scopes class if you are not satisfied with standard flow. + +**Container** is what you use to get your dependency. You just call `.get(SomeType)` and it finds a way to get you an instance of that type. It does not create things itself, but manages their lifecycle and caches. It delegates objects creation to providers which are passed during creation. + + +**Provider** is a collection of functions which really provide some objects. +Provider itself is a class with some attributes and methods. Each of them is either result of `provide`, `alias` or `decorate`. + +`@provide` can be used as a decorator for some method. This method will be called when corresponding dependency has to be created. Name of the method is not important: just check that it is different form other `Provider` attributes. Type hints do matter: they show what this method creates and what does it require. All method parameters are treated as other dependencies and created using container. + +If `provide` is used with some class then that class itself is treated as a factory (`__init__` is analyzed for parameters). But do not forget to assing that call to some attribute otherwise it will be ignored. + + + +### Tips + +* Add method and mark it with `@provide` decorator. It can be sync or async method returning some value. + ```python + class MyProvider(Provider): + @provide(scope=Scope.REQUEST) + def get_a(self) -> A: + return A() + ``` +* Want some finalization when exiting the scope? Make that method generator: + ```python + class MyProvider(Provider): + @provide(scope=Scope.REQUEST) + def get_a(self) -> Iterable[A]: + a = A() + yield a + a.close() + ``` +* Do not have any specific logic and just want to create class using its `__init__`? then add a provider attribute using `provide` as function passing that class. + ```python + class MyProvider(Provider): + a = provide(A, scope=Scope.REQUEST) + ``` +* Want to create a child class instance when parent is requested? add a `dependency` attribute to `provide` function with a parent class while passing child as a first parameter + ```python + class MyProvider(Provider): + a = provide(source=AChild, scope=Scope.REQUEST, provides=A) + ``` +* Having multiple interfaces which can be created as a same class with defined provider? Use alias: + ```python + class MyProvider(Provider): + p = alias(source=A, provides=AProtocol) + ``` + it works the same way as + ```python + class MyProvider(Provider): + @provide(scope=) + def p(self, a: A) -> AProtocol: + return a + ``` + +* Want to apply decorator pattern and do not want to alter existing provide method? Use `decorate`. It will construct object using earlie defined provider and then pass it to your decorator before returning from the container. + ```python + class MyProvider(Provider): + @decorate + def decorate_a(self, a: A) -> A: + return ADecorator(a) + ``` + Decorator function can also have additional parameters. + +* Want to go `async`? Make provide methods asynchronous. Create async container. Use `async with` and await `get` calls: +```python +class MyProvider(Provider): + @provide(scope=Scope.APP) + async def get_a(self) -> A: + return A() + +async with make_async_container(MyProvider()) as container: + a = await container.get(A) +``` + +* Having some data connected with scope which you want to use when solving dependencies? Set it when entering scope. These classes can be used as parameters of your `provide` methods +```python +with make_container(MyProvider(), context={App: app}) as container: + with container(context={RequestClass: request_instance}) as request_container: + pass +``` -See [examples](examples/sync_simple.py) \ No newline at end of file +* Having to many dependencies? Or maybe want to replace only part of them in tests keeping others? Create multiple `Provider` classes +```python +with make_container(MyProvider(), OtherProvider()) as container: +``` diff --git a/docs/technical_requirements.md b/docs/technical_requirements.md new file mode 100644 index 00000000..d95d4aed --- /dev/null +++ b/docs/technical_requirements.md @@ -0,0 +1,55 @@ +## Technical requirements for IoC-container + +#### 1. Scopes + +1. Library should support various number of scopes +2. All dependencies are attached to scopes before any of them can be created +3. There should be default set of scopes +4. Scopes are ordered. Order is defined when declaring scopes. +5. Scope can be entered and exited. +6. Scope can be entered not earlier than enter into previous one. +7. Same scope can be entered multiple times concurrently. +8. If the same dependency is requested more than one time within the scope the same instance is returned. Cache is not shared between concurrent instances of same scope +9. Dependency can require other dependencies of the same or previous scope. + +#### 2. Concurrency + +1. Containers should be allowed to use with multithreading or asyncio. Not required to support both within same object. +2. Dependency creation using async functions should be supported if container is configured to run in asyncio +3. Concurrent entrance of scopes must not break requirement of single instance of dependency. Type of concurrency model can be configured when creating container +4. User of container may be allowed to switch synchronization on or off for performance tuning + +#### 3. Clean dependencies + +1. Usage of container must not require modification of objects we are creating +2. Container must not require to be global variable. +4. Container can require code changes on the borders of scopes (e.g. application start, middlewares, request handlers) + +#### 4. Lifecycle + +1. Dependencies which require some cleanup must be cleaned up on the scope exit +2. Dependencies which do not require cleanup should somehow be supported + +#### 5. Context data + +1. It should be allowed to pass some data when entering the scope +2. Context data must be accessible when creating dependencies + +#### 6. Modularity + +1. There can be multiple containers within same code base for different purposes +2. There must be a way to assemble a container from some reusable parts. +3. Assembling of container should be done in runtime in local scope + +#### 7. Usability + +1. There should be a way to create dependency based on its `__init__` +2. When creating a dependency there should be a way to decide which subtype is used and request only its dependencies +3. There should be a way to reuse same object for multiple requested types +4. There should be a way to decorate dependency just adding new providers + +#### 8. Integration + +1. Additional helpers should be provided for some popular frameworks. E.g: flask, fastapi, aiogram, celery, apscheduler +2. These helpers should be optional +3. Additional integrations should be done without changing library code \ No newline at end of file diff --git a/examples/aiogram_bot.py b/examples/aiogram_bot.py index 860b6832..537f3ef2 100644 --- a/examples/aiogram_bot.py +++ b/examples/aiogram_bot.py @@ -73,7 +73,7 @@ async def start( async def main(): # real main logging.basicConfig(level=logging.INFO) - async with make_async_container(MyProvider(), with_lock=True) as container: + async with make_async_container(MyProvider()) as container: bot = Bot(token=API_TOKEN) dp = Dispatcher() for observer in dp.observers.values(): diff --git a/examples/async_simple.py b/examples/async_simple.py index 2752c379..7c7f74ee 100644 --- a/examples/async_simple.py +++ b/examples/async_simple.py @@ -21,9 +21,7 @@ async def get_str(self, dep: int) -> AsyncGenerator[str, None]: async def main(): - async with make_async_container( - MyProvider(1), with_lock=True, - ) as container: + async with make_async_container(MyProvider(1)) as container: print(await container.get(int)) async with container() as c_request: diff --git a/examples/di/classes.py b/examples/benchmarks/classes.py similarity index 100% rename from examples/di/classes.py rename to examples/benchmarks/classes.py diff --git a/examples/benchmarks/fastapi_app.py b/examples/benchmarks/fastapi_app.py new file mode 100644 index 00000000..c0f8b2d6 --- /dev/null +++ b/examples/benchmarks/fastapi_app.py @@ -0,0 +1,166 @@ +import logging +from contextlib import asynccontextmanager +from inspect import Parameter +from typing import Annotated, Callable, Iterable, NewType, get_type_hints + +import uvicorn +from fastapi import APIRouter +from fastapi import Depends as FastapiDepends +from fastapi import FastAPI, Request + +from dishka import Provider, Scope, make_async_container, provide +from dishka.inject import Depends, wrap_injection + + +# framework level +def inject(func): + hints = get_type_hints(func) + requests_param = next( + (name for name, hint in hints.items() if hint is Request), + None, + ) + if requests_param: + getter = lambda kwargs: kwargs[requests_param].state.container + additional_params = [] + else: + getter = lambda kwargs: kwargs["___r___"].state.container + additional_params = [Parameter( + name="___r___", + annotation=Request, + kind=Parameter.KEYWORD_ONLY, + )] + + return wrap_injection( + func=func, + remove_depends=True, + container_getter=getter, + additional_params=additional_params, + is_async=True, + ) + + +def container_middleware(): + async def add_request_container(request: Request, call_next): + async with request.app.state.container( + {Request: request} + ) as subcontainer: + request.state.container = subcontainer + return await call_next(request) + + return add_request_container + + +class Stub: + def __init__(self, dependency: Callable, **kwargs): + self._dependency = dependency + self._kwargs = kwargs + + def __call__(self): + raise NotImplementedError + + def __eq__(self, other) -> bool: + if isinstance(other, Stub): + return ( + self._dependency == other._dependency + and self._kwargs == other._kwargs + ) + else: + if not self._kwargs: + return self._dependency == other + return False + + def __hash__(self): + if not self._kwargs: + return hash(self._dependency) + serial = ( + self._dependency, + *self._kwargs.items(), + ) + return hash(serial) + + +# app dependency logic + +Host = NewType("Host", str) + + +class B: + def __init__(self, x: int): + pass + + +class C: + def __init__(self, x: int): + pass + + +class A: + def __init__(self, b: B, c: C): + pass + + +MyInt = NewType("MyInt", int) + + +class MyProvider(Provider): + @provide(scope=Scope.REQUEST) + async def get_a(self, b: B, c: C) -> A: + return A(b, c) + + @provide(scope=Scope.REQUEST) + async def get_b(self) -> Iterable[B]: + yield B(1) + + @provide(scope=Scope.REQUEST) + async def get_c(self) -> Iterable[C]: + yield C(1) + + +# app +router = APIRouter() + + +@router.get("/") +@inject +async def index( + *, + value: Annotated[A, Depends()], + value2: Annotated[A, Depends()], +) -> str: + return f"{value} {value is value2}" + + +@router.get("/f") +async def index( + *, + value: Annotated[A, FastapiDepends(Stub(A))], + value2: Annotated[A, FastapiDepends(Stub(A))], +) -> str: + return f"{value} {value is value2}" + + +def new_a(b: B = FastapiDepends(Stub(B)), c: C = FastapiDepends(Stub(C))): + return A(b, c) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + async with make_async_container(MyProvider()) as container: + app.state.container = container + yield + + +def create_app() -> FastAPI: + logging.basicConfig(level=logging.WARNING) + + app = FastAPI(lifespan=lifespan) + app.middleware("http")(container_middleware()) + app.dependency_overrides[A] = new_a + app.dependency_overrides[B] = lambda: B(1) + app.dependency_overrides[C] = lambda: C(1) + app.include_router(router) + return app + + +if __name__ == "__main__": + uvicorn.run(create_app(), host="0.0.0.0", port=8000) diff --git a/examples/di/with_di.py b/examples/benchmarks/with_di.py similarity index 100% rename from examples/di/with_di.py rename to examples/benchmarks/with_di.py diff --git a/examples/di/with_dishka.py b/examples/benchmarks/with_dishka.py similarity index 93% rename from examples/di/with_dishka.py rename to examples/benchmarks/with_dishka.py index 65ea97a2..0011688c 100644 --- a/examples/di/with_dishka.py +++ b/examples/benchmarks/with_dishka.py @@ -5,7 +5,7 @@ class MyProvider(Provider): - a = provide(A1, scope=MyScope.REQUEST, dependency=A) + a = provide(A1, scope=MyScope.REQUEST, provides=A) c1 = provide(CA, scope=MyScope.REQUEST) c2 = provide(CAA, scope=MyScope.REQUEST) c3 = provide(CAAA, scope=MyScope.REQUEST) diff --git a/examples/fastapi_app.py b/examples/fastapi_app.py index d6d9865a..3ce83060 100644 --- a/examples/fastapi_app.py +++ b/examples/fastapi_app.py @@ -1,15 +1,19 @@ import logging +from abc import abstractmethod from contextlib import asynccontextmanager from inspect import Parameter -from typing import Annotated, Callable, Iterable, NewType, get_type_hints +from typing import ( + Annotated, get_type_hints, Protocol, Any, get_origin, + get_args, +) import uvicorn from fastapi import APIRouter -from fastapi import Depends as FastapiDepends from fastapi import FastAPI, Request -from dishka import Provider, Scope, make_async_container, provide -from dishka.inject import Depends, wrap_injection +from dishka import ( + Depends, wrap_injection, Provider, Scope, make_async_container, provide, +) # framework level @@ -50,73 +54,38 @@ async def add_request_container(request: Request, call_next): return add_request_container -class Stub: - def __init__(self, dependency: Callable, **kwargs): - self._dependency = dependency - self._kwargs = kwargs - - def __call__(self): +# app core +class DbGateway(Protocol): + @abstractmethod + def get(self) -> str: raise NotImplementedError - def __eq__(self, other) -> bool: - if isinstance(other, Stub): - return ( - self._dependency == other._dependency - and self._kwargs == other._kwargs - ) - else: - if not self._kwargs: - return self._dependency == other - return False - - def __hash__(self): - if not self._kwargs: - return hash(self._dependency) - serial = ( - self._dependency, - *self._kwargs.items(), - ) - return hash(serial) - - -# app dependency logic - -Host = NewType("Host", str) +class FakeDbGateway(DbGateway): + def get(self) -> str: + return "Hello" -class B: - def __init__(self, x: int): - pass +class Interactor: + def __init__(self, db: DbGateway): + self.db = db -class C: - def __init__(self, x: int): - pass + def __call__(self) -> str: + return self.db.get() -class A: - def __init__(self, b: B, c: C): - pass - - -MyInt = NewType("MyInt", int) - - -class MyProvider(Provider): +# app dependency logic +class AdaptersProvider(Provider): @provide(scope=Scope.REQUEST) - async def get_a(self, b: B, c: C) -> A: - return A(b, c) + def get_db(self) -> DbGateway: + return FakeDbGateway() - @provide(scope=Scope.REQUEST) - async def get_b(self) -> Iterable[B]: - yield B(1) - @provide(scope=Scope.REQUEST) - async def get_c(self) -> Iterable[C]: - yield C(1) +class InteractorProvider(Provider): + i1 = provide(Interactor, scope=Scope.REQUEST) -# app +# presentation layer router = APIRouter() @@ -124,28 +93,18 @@ async def get_c(self) -> Iterable[C]: @inject async def index( *, - value: Annotated[A, Depends()], - value2: Annotated[A, Depends()], + interactor: Annotated[Interactor, Depends()], ) -> str: - return f"{value} {value is value2}" - - -@router.get("/f") -async def index( - *, - value: Annotated[A, FastapiDepends(Stub(A))], - value2: Annotated[A, FastapiDepends(Stub(A))], -) -> str: - return f"{value} {value is value2}" - - -def new_a(b: B = FastapiDepends(Stub(B)), c: C = FastapiDepends(Stub(C))): - return A(b, c) + result = interactor() + return result +# app configuration @asynccontextmanager async def lifespan(app: FastAPI): - async with make_async_container(MyProvider(), with_lock=True) as container: + async with make_async_container( + AdaptersProvider(), InteractorProvider(), + ) as container: app.state.container = container yield @@ -155,9 +114,6 @@ def create_app() -> FastAPI: app = FastAPI(lifespan=lifespan) app.middleware("http")(container_middleware()) - app.dependency_overrides[A] = new_a - app.dependency_overrides[B] = lambda: B(1) - app.dependency_overrides[C] = lambda: C(1) app.include_router(router) return app diff --git a/examples/sync_simple.py b/examples/sync_simple.py index 16298b5e..fcf7ff33 100644 --- a/examples/sync_simple.py +++ b/examples/sync_simple.py @@ -21,7 +21,7 @@ def __init__(self, a: int): self.a = a get_a = provide(A, scope=Scope.REQUEST) - get_basea = alias(A, dependency=BaseA) + get_basea = alias(source=A, provides=BaseA) @provide(scope=Scope.APP) def get_int(self) -> int: @@ -33,7 +33,7 @@ def get_str(self, dep: int) -> Generator[None, str, None]: def main(): - with make_container(MyProvider(1), with_lock=True) as container: + with make_container(MyProvider(1)) as container: print(container.get(int)) with container() as c_request: print(c_request.get(BaseA)) diff --git a/pyproject.toml b/pyproject.toml index 801d9475..4a069662 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,10 +24,12 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ] -dependencies = [] +dependencies = [ + 'exceptiongroup>=1.1.3; python_version<"3.11"', +] [project.urls] -"Homepage" = "https://github.com/tishka17/dishka" -"Bug Tracker" = "https://github.com/tishka17/dishka/issues" +"Homepage" = "https://github.com/reagento/dishka" +"Bug Tracker" = "https://github.com/reagento/dishka/issues" diff --git a/requirements_dev.txt b/requirements_dev.txt index 27de1df4..2102e4da 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,3 +1,4 @@ -ruff -pytest -pytest-asyncio +ruff==0.1.* +pytest==7.* +pytest-asyncio==0.23.* +pytest-repeat==0.9.* diff --git a/src/dishka/__init__.py b/src/dishka/__init__.py index e0fc3eb9..70872a59 100644 --- a/src/dishka/__init__.py +++ b/src/dishka/__init__.py @@ -2,12 +2,14 @@ "make_async_container", "AsyncContainer", "make_container", "Container", "Depends", "wrap_injection", - "Provider", "provide", "alias", + "Provider", + "alias", "decorate", "provide", "BaseScope", "Scope", ] from .async_container import AsyncContainer, make_async_container from .container import Container, make_container +from .dependency_source import alias, decorate, provide from .inject import Depends, wrap_injection -from .provider import Provider, alias, provide +from .provider import Provider from .scope import BaseScope, Scope diff --git a/src/dishka/async_container.py b/src/dishka/async_container.py index f381100e..14679034 100644 --- a/src/dishka/async_container.py +++ b/src/dishka/async_container.py @@ -2,8 +2,10 @@ from dataclasses import dataclass from typing import Callable, List, Optional, Type, TypeVar -from .provider import DependencyProvider, Provider, ProviderType -from .registry import Registry, make_registry +from .dependency_source import Factory, FactoryType +from .exceptions import ExitExceptionGroup +from .provider import Provider +from .registry import Registry, make_registries from .scope import BaseScope, Scope T = TypeVar("T") @@ -12,14 +14,14 @@ @dataclass class Exit: __slots__ = ("type", "callable") - type: ProviderType + type: FactoryType callable: Callable class AsyncContainer: __slots__ = ( "registry", "child_registries", "context", "parent_container", - "lock", "exits", + "lock", "_exits", ) def __init__( @@ -28,68 +30,68 @@ def __init__( *child_registries: Registry, parent_container: Optional["AsyncContainer"] = None, context: Optional[dict] = None, - with_lock: bool = False, + lock_factory: Callable[[], Lock] | None = None, ): self.registry = registry self.child_registries = child_registries - self.context = {} + self.context = {type(self): self} if context: self.context.update(context) self.parent_container = parent_container - if with_lock: - self.lock = Lock() + if lock_factory: + self.lock = lock_factory() else: self.lock = None - self.exits: List[Exit] = [] + self._exits: List[Exit] = [] - def _get_child( + def _create_child( self, context: Optional[dict], - with_lock: bool, + lock_factory: Callable[[], Lock] | None, ) -> "AsyncContainer": return AsyncContainer( *self.child_registries, parent_container=self, context=context, - with_lock=with_lock, + lock_factory=lock_factory, ) def __call__( self, context: Optional[dict] = None, - with_lock: bool = False, + lock_factory: Callable[[], Lock] | None = None, ) -> "AsyncContextWrapper": + """ + Prepare container for entering the inner scope. + :param context: Data which will available in inner scope + :param lock_factory: Callable to create lock instance or None + :return: async context manager for inner scope + """ if not self.child_registries: raise ValueError("No child scopes found") - return AsyncContextWrapper(self._get_child(context, with_lock)) + return AsyncContextWrapper(self._create_child(context, lock_factory)) - async def _get_parent(self, dependency_type: Type[T]) -> T: - return await self.parent_container.get(dependency_type) - - async def _get_self( - self, - dep_provider: DependencyProvider, - ) -> T: + async def _get_from_self(self, factory: Factory) -> T: sub_dependencies = [ await self._get_unlocked(dependency) - for dependency in dep_provider.dependencies + for dependency in factory.dependencies ] - if dep_provider.type is ProviderType.GENERATOR: - generator = dep_provider.callable(*sub_dependencies) - self.exits.append(Exit(dep_provider.type, generator)) + if factory.type is FactoryType.GENERATOR: + generator = factory.source(*sub_dependencies) + self._exits.append(Exit(factory.type, generator)) return next(generator) - elif dep_provider.type is ProviderType.ASYNC_GENERATOR: - generator = dep_provider.callable(*sub_dependencies) - self.exits.append(Exit(dep_provider.type, generator)) + elif factory.type is FactoryType.ASYNC_GENERATOR: + generator = factory.source(*sub_dependencies) + self._exits.append(Exit(factory.type, generator)) return await anext(generator) - elif dep_provider.type is ProviderType.ASYNC_FACTORY: - return await dep_provider.callable(*sub_dependencies) - elif dep_provider.type is ProviderType.FACTORY: - return dep_provider.callable(*sub_dependencies) - elif dep_provider.type is ProviderType.VALUE: - return dep_provider.callable + elif factory.type is FactoryType.ASYNC_FACTORY: + return await factory.source(*sub_dependencies) + elif factory.type is FactoryType.FACTORY: + return factory.source(*sub_dependencies) + elif factory.type is FactoryType.VALUE: + return factory.source else: - raise ValueError(f"Unsupported type {dep_provider.type}") + raise ValueError(f"Unsupported type {factory.type}") async def get(self, dependency_type: Type[T]) -> T: lock = self.lock @@ -106,26 +108,26 @@ async def _get_unlocked(self, dependency_type: Type[T]) -> T: if not self.parent_container: raise ValueError(f"No provider found for {dependency_type!r}") return await self.parent_container.get(dependency_type) - solved = await self._get_self(provider) + solved = await self._get_from_self(provider) self.context[dependency_type] = solved return solved async def close(self): - e = None - for exit_generator in self.exits: + errors = [] + for exit_generator in self._exits[::-1]: try: - if exit_generator.type is ProviderType.ASYNC_GENERATOR: + if exit_generator.type is FactoryType.ASYNC_GENERATOR: await anext(exit_generator.callable) - elif exit_generator.type is ProviderType.GENERATOR: + elif exit_generator.type is FactoryType.GENERATOR: next(exit_generator.callable) except StopIteration: pass except StopAsyncIteration: pass except Exception as err: # noqa: BLE001 - e = err - if e: - raise e + errors.append(err) + if errors: + raise ExitExceptionGroup("Cleanup context errors", errors) class AsyncContextWrapper: @@ -143,12 +145,11 @@ def make_async_container( *providers: Provider, scopes: Type[BaseScope] = Scope, context: Optional[dict] = None, - with_lock: bool = False, + lock_factory: Callable[[], Lock] | None = Lock, ) -> AsyncContextWrapper: - registries = [ - make_registry(*providers, scope=scope) - for scope in scopes - ] - return AsyncContextWrapper( - AsyncContainer(*registries, context=context, with_lock=with_lock), - ) + registries = make_registries(*providers, scopes=scopes) + return AsyncContextWrapper(AsyncContainer( + *registries, + context=context, + lock_factory=lock_factory, + )) diff --git a/src/dishka/container.py b/src/dishka/container.py index 6e239724..2d688a06 100644 --- a/src/dishka/container.py +++ b/src/dishka/container.py @@ -1,17 +1,27 @@ +from dataclasses import dataclass from threading import Lock -from typing import Optional, Type, TypeVar +from typing import Callable, List, Optional, Type, TypeVar -from .provider import DependencyProvider, ProviderType -from .registry import Registry, make_registry +from .dependency_source import Factory, FactoryType +from .exceptions import ExitExceptionGroup +from .provider import Provider +from .registry import Registry, make_registries from .scope import BaseScope, Scope T = TypeVar("T") +@dataclass +class Exit: + __slots__ = ("type", "callable") + type: FactoryType + callable: Callable + + class Container: __slots__ = ( "registry", "child_registries", "context", "parent_container", - "lock", "exits", + "lock", "_exits", ) def __init__( @@ -20,62 +30,62 @@ def __init__( *child_registries: Registry, parent_container: Optional["Container"] = None, context: Optional[dict] = None, - with_lock: bool = False, + lock_factory: Callable[[], Lock] | None = None, ): self.registry = registry self.child_registries = child_registries - self.context = {} + self.context = {type(self): self} if context: self.context.update(context) self.parent_container = parent_container - if with_lock: - self.lock = Lock() + if lock_factory: + self.lock = lock_factory() else: self.lock = None - self.exits = [] + self._exits: List[Exit] = [] - def _get_child( + def _create_child( self, context: Optional[dict], - with_lock: bool, + lock_factory: Callable[[], Lock] | None, ) -> "Container": return Container( *self.child_registries, parent_container=self, context=context, - with_lock=with_lock, + lock_factory=lock_factory, ) def __call__( self, context: Optional[dict] = None, - with_lock: bool = False, + lock_factory: Callable[[], Lock] | None = None, ) -> "ContextWrapper": + """ + Prepare container for entering the inner scope. + :param context: Data which will available in inner scope + :param lock_factory: Callable to create lock instance or None + :return: context manager for inner scope + """ if not self.child_registries: raise ValueError("No child scopes found") - return ContextWrapper(self._get_child(context, with_lock)) - - def _get_parent(self, dependency_type: Type[T]) -> T: - return self.parent_container.get(dependency_type) + return ContextWrapper(self._create_child(context, lock_factory)) - def _get_self( - self, - dep_provider: DependencyProvider, - ) -> T: + def _get_from_self(self, factory: Factory) -> T: sub_dependencies = [ self._get_unlocked(dependency) - for dependency in dep_provider.dependencies + for dependency in factory.dependencies ] - if dep_provider.type is ProviderType.GENERATOR: - generator = dep_provider.callable(*sub_dependencies) - self.exits.append(generator) + if factory.type is FactoryType.GENERATOR: + generator = factory.source(*sub_dependencies) + self._exits.append(Exit(factory.type, generator)) return next(generator) - elif dep_provider.type is ProviderType.FACTORY: - return dep_provider.callable(*sub_dependencies) - elif dep_provider.type is ProviderType.VALUE: - return dep_provider.callable + elif factory.type is FactoryType.FACTORY: + return factory.source(*sub_dependencies) + elif factory.type is FactoryType.VALUE: + return factory.source else: - raise ValueError(f"Unsupported type {dep_provider.type}") + raise ValueError(f"Unsupported type {factory.type}") def get(self, dependency_type: Type[T]) -> T: lock = self.lock @@ -92,24 +102,27 @@ def _get_unlocked(self, dependency_type: Type[T]) -> T: if not self.parent_container: raise ValueError(f"No provider found for {dependency_type!r}") return self.parent_container.get(dependency_type) - solved = self._get_self(provider) + solved = self._get_from_self(provider) self.context[dependency_type] = solved return solved - def close(self): - e = None - for exit_generator in self.exits: + def close(self) -> None: + errors = [] + for exit_generator in self._exits[::-1]: try: - next(exit_generator) + if exit_generator.type is FactoryType.GENERATOR: + next(exit_generator.callable) except StopIteration: pass except Exception as err: # noqa: BLE001 - e = err - if e: - raise e + errors.append(err) + if errors: + raise ExitExceptionGroup("Cleanup context errors", errors) class ContextWrapper: + __slots__ = ("container",) + def __init__(self, container: Container): self.container = container @@ -121,15 +134,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): def make_container( - *providers, + *providers: Provider, scopes: Type[BaseScope] = Scope, context: Optional[dict] = None, - with_lock: bool = False, + lock_factory: Callable[[], Lock] | None = None, ) -> ContextWrapper: - registries = [ - make_registry(*providers, scope=scope) - for scope in scopes - ] + registries = make_registries(*providers, scopes=scopes) return ContextWrapper( - Container(*registries, context=context, with_lock=with_lock), + Container(*registries, context=context, lock_factory=lock_factory), ) diff --git a/src/dishka/dependency_source.py b/src/dishka/dependency_source.py new file mode 100644 index 00000000..f3988f5d --- /dev/null +++ b/src/dishka/dependency_source.py @@ -0,0 +1,260 @@ +from collections.abc import AsyncIterable, Iterable +from enum import Enum +from inspect import ( + isasyncgenfunction, + isclass, + iscoroutinefunction, + isgeneratorfunction, +) +from typing import ( + Any, + Callable, + Optional, + Sequence, + Type, + get_args, + get_origin, + get_type_hints, + overload, +) + +from .scope import BaseScope + + +class FactoryType(Enum): + GENERATOR = "generator" + ASYNC_GENERATOR = "async_generator" + FACTORY = "factory" + ASYNC_FACTORY = "async_factory" + VALUE = "value" + + +def _identity(x: Any) -> Any: + return x + + +class Factory: + __slots__ = ( + "dependencies", "source", "provides", "scope", "type", + "is_to_bound", + ) + + def __init__( + self, + dependencies: Sequence[Any], + source: Any, + provides: Type, + scope: Optional[BaseScope], + type: FactoryType, + is_to_bound: bool, + ): + self.dependencies = dependencies + self.source = source + self.provides = provides + self.scope = scope + self.type = type + self.is_to_bound = is_to_bound + + def __get__(self, instance, owner): + if instance is None: + return self + if self.is_to_bound: + source = self.source.__get__(instance, owner) + else: + source = self.source + return Factory( + dependencies=self.dependencies, + source=source, + provides=self.provides, + scope=self.scope, + type=self.type, + is_to_bound=False, + ) + + +def make_factory( + provides: Any, + scope: Optional[BaseScope], + source: Callable, +) -> Factory: + if isclass(source): + hints = get_type_hints(source.__init__, include_extras=True) + hints.pop("return", None) + possible_dependency = source + is_to_bind = False + else: + hints = get_type_hints(source, include_extras=True) + possible_dependency = hints.pop("return", None) + is_to_bind = True + + if isclass(source): + provider_type = FactoryType.FACTORY + elif isasyncgenfunction(source): + provider_type = FactoryType.ASYNC_GENERATOR + if get_origin(possible_dependency) is AsyncIterable: + possible_dependency = get_args(possible_dependency)[0] + else: # async generator + possible_dependency = get_args(possible_dependency)[0] + elif isgeneratorfunction(source): + provider_type = FactoryType.GENERATOR + if get_origin(possible_dependency) is Iterable: + possible_dependency = get_args(possible_dependency)[0] + else: # generator + possible_dependency = get_args(possible_dependency)[1] + elif iscoroutinefunction(source): + provider_type = FactoryType.ASYNC_FACTORY + else: + provider_type = FactoryType.FACTORY + + return Factory( + dependencies=list(hints.values()), + type=provider_type, + source=source, + scope=scope, + provides=provides or possible_dependency, + is_to_bound=is_to_bind, + ) + + +@overload +def provide( + *, + scope: BaseScope, + provides: Any = None, +) -> Callable[[Callable], Factory]: + ... + + +@overload +def provide( + source: Callable | Type, + *, + scope: BaseScope, + provides: Any = None, +) -> Factory: + ... + + +def provide( + source: Callable | Type | None = None, + *, + scope: BaseScope, + provides: Any = None, +) -> Factory | Callable[[Callable], Factory]: + """ + Mark a method or class as providing some dependency. + + If used as a method decorator then return annotation is used + to determine what is provided. User `provides` to override that. + Method parameters are analyzed and passed automatically. + + If used with a class a first parameter than `__init__` method parameters + are passed automatically. If no provides is passed then it is + supposed that class itself is a provided dependency. + + Return value must be saved as a `Provider` class attribute and + not intended for direct usage + + :param source: Method to decorate or class. + :param scope: Scope of the dependency to limit its lifetime + :param provides: Dependency type which is provided by this factory + :return: instance of Factory or a decorator returning it + """ + if source is not None: + return make_factory(provides, scope, source) + + def scoped(func): + return make_factory(provides, scope, func) + + return scoped + + +class Alias: + __slots__ = ("source", "provides") + + def __init__(self, source, provides): + self.source = source + self.provides = provides + + def as_factory(self, scope: BaseScope) -> Factory: + return Factory( + scope=scope, + source=_identity, + provides=self.provides, + is_to_bound=False, + dependencies=[self.source], + type=FactoryType.FACTORY, + ) + + def __get__(self, instance, owner): + return self + + +def alias( + *, + source: Type, + provides: Type, +) -> Alias: + return Alias( + source=source, + provides=provides, + ) + + +class Decorator: + __slots__ = ("provides", "factory") + + def __init__(self, factory: Factory): + self.factory = factory + self.provides = factory.provides + + def as_factory( + self, scope: BaseScope, new_dependency: Any, + ) -> Factory: + return Factory( + scope=scope, + source=self.factory.source, + provides=self.factory.provides, + is_to_bound=self.factory.is_to_bound, + dependencies=[ + new_dependency if dep is self.provides else dep + for dep in self.factory.dependencies + ], + type=self.factory.type, + ) + + def __get__(self, instance, owner): + return Decorator(self.factory.__get__(instance, owner)) + + +@overload +def decorate( + *, + provides: Any = None, +) -> Callable[[Callable], Decorator]: + ... + + +@overload +def decorate( + source: Callable | Type, + *, + provides: Any = None, +) -> Decorator: + ... + + +def decorate( + source: Callable | Type | None = None, + provides: Any = None, +) -> Decorator | Callable[[Callable], Decorator]: + if source is not None: + return Decorator(make_factory(provides, None, source)) + + def scoped(func): + return Decorator(make_factory(provides, None, func)) + + return scoped + + +DependencySource = Alias | Factory | Decorator diff --git a/src/dishka/exceptions.py b/src/dishka/exceptions.py new file mode 100644 index 00000000..6841be5b --- /dev/null +++ b/src/dishka/exceptions.py @@ -0,0 +1,16 @@ +try: + from builtins import ExceptionGroup +except ImportError: + from exceptiongroup import ExceptionGroup + + +class DishkaException: + pass + + +class InvalidGraphError(DishkaException, ValueError): + pass + + +class ExitExceptionGroup(ExceptionGroup, DishkaException): + pass diff --git a/src/dishka/inject.py b/src/dishka/inject.py index 0b661fb1..1c912d82 100644 --- a/src/dishka/inject.py +++ b/src/dishka/inject.py @@ -17,30 +17,46 @@ def __init__(self, param: Any = None): self.param = param +def default_parse_dependency( + parameter: Parameter, + hint: Any, +) -> Any: + """ Resolve dependency type or return None if it is not a dependency """ + if get_origin(hint) is not Annotated: + return None + dep = next( + (arg for arg in get_args(hint) if isinstance(arg, Depends)), + None, + ) + if not dep: + return None + if dep.param is None: + return get_args(hint)[0] + else: + return dep.param + + +DependencyParser = Callable[[Parameter, Any], Any] + + def wrap_injection( func: Callable, container_getter: Callable[[dict], Container], remove_depends: bool = True, additional_params: Sequence[Parameter] = (), is_async: bool = False, -): + parse_dependency: DependencyParser = default_parse_dependency, +) -> Callable: hints = get_type_hints(func, include_extras=True) func_signature = signature(func) dependencies = {} - for name, hint in hints.items(): - if get_origin(hint) is not Annotated: - continue - dep = next( - (arg for arg in get_args(hint) if isinstance(arg, Depends)), - None, - ) - if not dep: + for name, param in func_signature.parameters.items(): + hint = hints.get(name, Any) + dep = parse_dependency(param, hint) + if dep is None: continue - if dep.param is None: - dependencies[name] = get_args(hint)[0] - else: - dependencies[name] = dep.param + dependencies[name] = dep if remove_depends: new_annotations = { diff --git a/src/dishka/provider.py b/src/dishka/provider.py index 5bf73916..80a327b2 100644 --- a/src/dishka/provider.py +++ b/src/dishka/provider.py @@ -1,164 +1,49 @@ -from collections.abc import AsyncIterable, Iterable -from enum import Enum -from inspect import ( - isasyncgenfunction, - isclass, - iscoroutinefunction, - isgeneratorfunction, -) -from typing import ( - Any, - Callable, - Optional, - Sequence, - Type, - Union, - get_args, - get_origin, - get_type_hints, -) +import inspect +from typing import Any, List -from .scope import BaseScope +from .dependency_source import Alias, Decorator, DependencySource, Factory +from .exceptions import InvalidGraphError -class ProviderType(Enum): - GENERATOR = "generator" - ASYNC_GENERATOR = "async_generator" - FACTORY = "factory" - ASYNC_FACTORY = "async_factory" - VALUE = "value" +def is_dependency_source(attribute: Any) -> bool: + return isinstance(attribute, DependencySource) -class DependencyProvider: - __slots__ = ( - "dependencies", "callable", "result_type", "scope", "type", - "is_to_bound", - ) - - def __init__( - self, - dependencies: Sequence, - callable: Callable, - result_type: Type, - scope: Optional[BaseScope], - type: ProviderType, - is_to_bound: bool, - ): - self.dependencies = dependencies - self.callable = callable - self.result_type = result_type - self.scope = scope - self.type = type - self.is_to_bound = is_to_bound - - def __get__(self, instance, owner): - if instance is None: - return self - if self.is_to_bound: - callable = self.callable.__get__(instance, owner) - else: - callable = self.callable - return DependencyProvider( - dependencies=self.dependencies, - callable=callable, - result_type=self.result_type, - scope=self.scope, - type=self.type, - is_to_bound=False, - ) - - def aliased(self, target: Type): - return DependencyProvider( - dependencies=self.dependencies, - callable=self.callable, - result_type=target, - scope=self.scope, - type=self.type, - is_to_bound=self.is_to_bound, - ) - - -def make_dependency_provider( - dependency: Any, - scope: Optional[BaseScope], - func: Callable, -): - if isclass(func): - hints = get_type_hints(func.__init__, include_extras=True) - hints.pop("return", None) - possible_dependency = func - is_to_bind = False - else: - hints = get_type_hints(func, include_extras=True) - possible_dependency = hints.pop("return", None) - is_to_bind = True - - if isclass(func): - provider_type = ProviderType.FACTORY - elif isasyncgenfunction(func): - provider_type = ProviderType.ASYNC_GENERATOR - if get_origin(possible_dependency) is AsyncIterable: - possible_dependency = get_args(possible_dependency)[0] - else: # async generator - possible_dependency = get_args(possible_dependency)[0] - elif isgeneratorfunction(func): - provider_type = ProviderType.GENERATOR - if get_origin(possible_dependency) is Iterable: - possible_dependency = get_args(possible_dependency)[0] - else: # generator - possible_dependency = get_args(possible_dependency)[1] - elif iscoroutinefunction(func): - provider_type = ProviderType.ASYNC_FACTORY - else: - provider_type = ProviderType.FACTORY - - return DependencyProvider( - dependencies=list(hints.values()), - type=provider_type, - callable=func, - scope=scope, - result_type=dependency or possible_dependency, - is_to_bound=is_to_bind, - ) - - -class Alias: - def __init__(self, target, result_type): - self.target = target - self.result_type = result_type - - -def alias( - target: Type, - dependency: Any = None, -): - return Alias( - target=target, - result_type=dependency, - ) - - -def provide( - func: Union[None, Callable] = None, - *, - scope: BaseScope = None, - dependency: Any = None, -): - if func is not None: - return make_dependency_provider(dependency, scope, func) +class Provider: + """ + A collection of dependency sources. - def scoped(func): - return make_dependency_provider(dependency, scope, func) + Inherit this class and add attributes using + `provide`, `alias` or `decorate`. - return scoped + You can use `__init__`, regular methods and attributes as usual, + they won't be analyzed when creating a container + The only intended usage of providers is to pass them when + creating a container + """ -class Provider: def __init__(self): - self.dependencies = {} - self.aliases = [] - for name, attr in vars(type(self)).items(): - if isinstance(attr, DependencyProvider): - self.dependencies[attr.result_type] = getattr(self, name) - elif isinstance(attr, Alias): - self.aliases.append(attr) + self.factories: List[Factory] = [] + self.aliases: List[Alias] = [] + self.decorators: List[Decorator] = [] + self._init_dependency_sources() + + def _init_dependency_sources(self) -> None: + processed_types = {} + + source: DependencySource + for name, source in inspect.getmembers(self, is_dependency_source): + if source.provides in processed_types: + raise InvalidGraphError( + f"Type {source.provides} is registered multiple times " + f"in the same {Provider} by attributes " + f"{processed_types[source.provides]!r} and {name!r}", + ) + if isinstance(source, Alias): + self.aliases.append(source) + if isinstance(source, Factory): + self.factories.append(source) + if isinstance(source, Decorator): + self.decorators.append(source) + processed_types[source.provides] = name diff --git a/src/dishka/registry.py b/src/dishka/registry.py index 74527d5b..5bc04129 100644 --- a/src/dishka/registry.py +++ b/src/dishka/registry.py @@ -1,35 +1,73 @@ -from typing import Any +from collections import defaultdict +from typing import Any, List, NewType, Type -from .provider import DependencyProvider, Provider +from .dependency_source import Factory +from .exceptions import InvalidGraphError +from .provider import Provider from .scope import BaseScope class Registry: - __slots__ = ("scope", "_providers") + __slots__ = ("scope", "_factories") def __init__(self, scope: BaseScope): - self._providers = {} + self._factories: dict[Type, Factory] = {} self.scope = scope - def add_provider(self, provider: DependencyProvider): - self._providers[provider.result_type] = provider + def add_provider(self, factory: Factory): + self._factories[factory.provides] = factory - def get_provider(self, dependency: Any): - return self._providers.get(dependency) + def get_provider(self, dependency: Any) -> Factory: + return self._factories.get(dependency) -def make_registry(*providers: Provider, scope: BaseScope) -> Registry: - registry = Registry(scope) +def make_registries( + *providers: Provider, scopes: Type[BaseScope], +) -> List[Registry]: + dep_scopes: dict[Type, BaseScope] = {} + alias_sources = {} for provider in providers: - for dependency_provider in provider.dependencies.values(): - if dependency_provider.scope is scope: - registry.add_provider(dependency_provider) + for source in provider.factories: + dep_scopes[source.provides] = source.scope + for source in provider.aliases: + alias_sources[source.provides] = source.source + + registries = {scope: Registry(scope) for scope in scopes} + decorator_depth: dict[Type, int] = defaultdict(int) for provider in providers: - for alias in provider.aliases: - dependency_provider = registry.get_provider(alias.target) - if dependency_provider: - registry.add_provider( - dependency_provider.aliased(alias.result_type), - ) - return registry + for source in provider.factories: + scope = source.scope + registries[scope].add_provider(source) + for source in provider.aliases: + alias_source = source.source + visited_types = [alias_source] + while alias_source not in dep_scopes: + alias_source = alias_sources[alias_source] + if alias_source in visited_types: + raise InvalidGraphError( + f"Cycle aliases detected {visited_types}", + ) + visited_types.append(alias_source) + scope = dep_scopes[alias_source] + dep_scopes[source.provides] = scope + source = source.as_factory(scope) + registries[scope].add_provider(source) + for source in provider.decorators: + provides = source.provides + scope = dep_scopes[provides] + registry = registries[scope] + undecorated_type = NewType( + f"{provides.__name__}@{decorator_depth[provides]}", + source.provides, + ) + decorator_depth[provides] += 1 + old_provider = registry.get_provider(provides) + old_provider.provides = undecorated_type + registry.add_provider(old_provider) + source = source.as_factory( + scope, undecorated_type, + ) + registries[scope].add_provider(source) + + return list(registries.values()) diff --git a/tests/container/__init__.py b/tests/container/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/container/test_alias.py b/tests/container/test_alias.py new file mode 100644 index 00000000..b91315ad --- /dev/null +++ b/tests/container/test_alias.py @@ -0,0 +1,33 @@ +import pytest + +from dishka import Provider, Scope, alias, make_container, provide + + +class AliasProvider(Provider): + @provide(scope=Scope.APP) + def provide_int(self) -> int: + return 42 + + aliased_complex = alias(source=float, provides=complex) + aliased_float = alias(source=int, provides=float) + + +def test_alias(): + with make_container(AliasProvider()) as container: + assert container.get(float) == container.get(int) + + +def test_alias_to_alias(): + with make_container(AliasProvider()) as container: + assert container.get(complex) == container.get(int) + + +class CycleProvider(Provider): + a = alias(source=int, provides=bool) + b = alias(source=bool, provides=float) + c = alias(source=float, provides=int) + + +def test_cycle(): + with pytest.raises(ValueError): + make_container(CycleProvider()) diff --git a/tests/container/test_cache.py b/tests/container/test_cache.py new file mode 100644 index 00000000..32e26d6c --- /dev/null +++ b/tests/container/test_cache.py @@ -0,0 +1,75 @@ +import pytest + +from dishka import ( + Provider, + Scope, + alias, + make_async_container, + make_container, + provide, +) + + +def test_cache_sync(): + class MyProvider(Provider): + value = 0 + + @provide(scope=Scope.REQUEST) + def get_int(self) -> int: + self.value += 1 + return self.value + + with make_container(MyProvider()) as container: + with container() as state: + assert state.get(int) == 1 + assert state.get(int) == 1 + with container() as state: + assert state.get(int) == 2 + assert state.get(int) == 2 + + +@pytest.mark.asyncio +async def test_cache_async(): + class MyProvider(Provider): + value = 0 + + @provide(scope=Scope.REQUEST) + async def get_int(self) -> int: + self.value += 1 + return self.value + + async with make_async_container(MyProvider()) as container: + async with container() as state: + assert await state.get(int) == 1 + assert await state.get(int) == 1 + async with container() as state: + assert await state.get(int) == 2 + assert await state.get(int) == 2 + + +@pytest.fixture() +def alias_provider(): + class MyProvider(Provider): + value = 0 + + @provide(scope=Scope.APP) + def get_int(self) -> int: + self.value += 1 + return self.value + + float = alias(source=int, provides=float) + + return MyProvider() + + +def test_alias_sync(alias_provider): + with make_container(alias_provider) as container: + assert container.get(int) == 1 + assert container.get(float) == 1 + + +@pytest.mark.asyncio +async def test_alias_async(alias_provider): + async with make_async_container(alias_provider) as container: + assert await container.get(int) == 1 + assert await container.get(float) == 1 diff --git a/tests/container/test_concurrency.py b/tests/container/test_concurrency.py new file mode 100644 index 00000000..76c6059a --- /dev/null +++ b/tests/container/test_concurrency.py @@ -0,0 +1,93 @@ +import asyncio +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import Mock + +import pytest + +from dishka import ( + AsyncContainer, + Container, + Provider, + Scope, + make_async_container, + make_container, + provide, +) + + +class SyncProvider(Provider): + def __init__(self, event: threading.Event, mock: Mock): + super().__init__() + self.event = event + self.mock = mock + + @provide(scope=Scope.APP) + def get_int(self) -> int: + self.event.wait() + return self.mock() + + @provide(scope=Scope.APP) + def get_str(self, value: int) -> str: + return "str" + + +def sync_get(container: Container): + container.get(str) + + +@pytest.mark.repeat(10) +def test_cache_sync(): + int_getter = Mock(return_value=123) + event = threading.Event() + provider = SyncProvider(event, int_getter) + with ThreadPoolExecutor() as pool: + with make_container( + provider, lock_factory=threading.Lock, + ) as container: + pool.submit(sync_get, container) + pool.submit(sync_get, container) + time.sleep(0.01) + event.set() + int_getter.assert_called_once_with() + + +class AsyncProvider(Provider): + def __init__(self, event: asyncio.Event, mock: Mock): + super().__init__() + self.event = event + self.mock = mock + + @provide(scope=Scope.APP) + async def get_int(self) -> int: + await self.event.wait() + return self.mock() + + @provide(scope=Scope.APP) + def get_str(self, value: int) -> str: + return "str" + + +async def async_get(container: AsyncContainer): + await container.get(str) + + +@pytest.mark.repeat(10) +@pytest.mark.asyncio +async def test_cache_async(): + int_getter = Mock(return_value=123) + event = asyncio.Event() + provider = AsyncProvider(event, int_getter) + + async with make_async_container( + provider, lock_factory=asyncio.Lock, + ) as container: + t1 = asyncio.create_task(async_get(container)) + t2 = asyncio.create_task(async_get(container)) + await asyncio.sleep(0.01) + event.set() + await t1 + await t2 + + int_getter.assert_called_once_with() diff --git a/tests/container/test_decorator.py b/tests/container/test_decorator.py new file mode 100644 index 00000000..4ed84de1 --- /dev/null +++ b/tests/container/test_decorator.py @@ -0,0 +1,84 @@ +import pytest + +from dishka import Provider, Scope, alias, decorate, make_container, provide + + +class A: + pass + + +class A1(A): + pass + + +class A2(A1): + pass + + +class ADecorator: + def __init__(self, a: A): + self.a = a + + +def test_simple(): + class MyProvider(Provider): + a = provide(A, scope=Scope.APP) + + class DProvider(Provider): + ad = decorate(ADecorator, provides=A) + + with make_container(MyProvider(), DProvider()) as container: + a = container.get(A) + assert isinstance(a, ADecorator) + assert isinstance(a.a, A) + + +def test_alias(): + class MyProvider(Provider): + a2 = provide(A2, scope=Scope.APP) + a1 = alias(source=A2, provides=A1) + a = alias(source=A1, provides=A) + + class DProvider(Provider): + @decorate + def decorated(self, a: A1) -> A1: + return ADecorator(a) + + with make_container(MyProvider(), DProvider()) as container: + a1 = container.get(A1) + assert isinstance(a1, ADecorator) + assert isinstance(a1.a, A2) + + a2 = container.get(A2) + assert isinstance(a2, A2) + assert a2 is a1.a + + a = container.get(A) + assert a is a1 + + +def test_double_error(): + class MyProvider(Provider): + a = provide(A, scope=Scope.APP) + ad = decorate(ADecorator, provides=A) + ad2 = decorate(ADecorator, provides=A) + + with pytest.raises(ValueError): + MyProvider() + + +def test_double_ok(): + class MyProvider(Provider): + a = provide(A, scope=Scope.APP) + + class DProvider(Provider): + ad = decorate(ADecorator, provides=A) + + class D2Provider(Provider): + ad2 = decorate(ADecorator, provides=A) + + with make_container(MyProvider(), DProvider(), D2Provider()) as container: + a = container.get(A) + assert isinstance(a, ADecorator) + assert isinstance(a.a, ADecorator) + assert isinstance(a.a.a, A) diff --git a/tests/container/test_dynamic.py b/tests/container/test_dynamic.py new file mode 100644 index 00000000..dcc49cca --- /dev/null +++ b/tests/container/test_dynamic.py @@ -0,0 +1,37 @@ +from typing import NewType + +from dishka import Container, Provider, Scope, make_container, provide + +Request = NewType("Request", int) + + +class A: + pass + + +class A0(A): + pass + + +class A1(A): + pass + + +class MyProvider(Provider): + a0 = provide(A0, scope=Scope.APP) + a1 = provide(A1, scope=Scope.APP) + + @provide(scope=Scope.REQUEST) + def get_a(self, container: Container, request: Request) -> A: + if request == 0: + return container.get(A0) + else: + return container.get(A1) + + +def test_dynamic(): + with make_container(MyProvider()) as container: + with container({Request: 0}) as c: + assert type(c.get(A)) is A0 + with container({Request: 1}) as c: + assert type(c.get(A)) is A1 diff --git a/tests/container/test_exceptions.py b/tests/container/test_exceptions.py new file mode 100644 index 00000000..ff84b304 --- /dev/null +++ b/tests/container/test_exceptions.py @@ -0,0 +1,76 @@ +from typing import AsyncIterable, Iterable, NewType +from unittest.mock import Mock + +import pytest + +from dishka import ( + Provider, + Scope, + make_async_container, + make_container, + provide, +) +from dishka.exceptions import ExitExceptionGroup + + +class MyError(Exception): + pass + + +SyncError = NewType("SyncError", int) +SyncFinalizationError = NewType("SyncFinalizationError", int) +AsyncError = NewType("SyncError", int) +AsyncFinalizationError = NewType("SyncFinalizationError", int) + + +class MyProvider(Provider): + def __init__(self, release_mock: Mock): + super().__init__() + self.release_mock = release_mock + + @provide(scope=Scope.APP) + def get_int(self) -> Iterable[int]: + yield 1 + self.release_mock() + + @provide(scope=Scope.APP) + def get1(self, value: int) -> SyncError: + raise MyError + + @provide(scope=Scope.APP) + def get2(self, value: int) -> Iterable[SyncFinalizationError]: + yield value + raise MyError + + @provide(scope=Scope.APP) + async def get3(self, value: int) -> AsyncError: + raise MyError + + @provide(scope=Scope.APP) + async def get4(self, value: int) -> AsyncIterable[AsyncFinalizationError]: + yield value + raise MyError + + +@pytest.mark.parametrize("dep_type", [ + SyncFinalizationError, +]) +def test_sync(dep_type): + finalizer = Mock(return_value=123) + with pytest.raises(ExitExceptionGroup): + with make_container(MyProvider(finalizer)) as container: + container.get(dep_type) + finalizer.assert_called_once() + + +@pytest.mark.parametrize("dep_type", [ + SyncFinalizationError, + AsyncFinalizationError, +]) +@pytest.mark.asyncio +async def test_async(dep_type): + finalizer = Mock(return_value=123) + with pytest.raises(ExitExceptionGroup): + async with make_async_container(MyProvider(finalizer)) as container: + await container.get(dep_type) + finalizer.assert_called_once() diff --git a/tests/test_container.py b/tests/container/test_resolve.py similarity index 54% rename from tests/test_container.py rename to tests/container/test_resolve.py index 21ff9edc..230d0181 100644 --- a/tests/test_container.py +++ b/tests/container/test_resolve.py @@ -7,7 +7,7 @@ make_container, provide, ) -from .sample_providers import ( +from ..sample_providers import ( ClassA, async_func_a, async_gen_a, @@ -66,44 +66,3 @@ def get_int(self) -> int: assert a assert a.dep == 100 assert a.closed == closed - - -def test_cache_sync(): - class MyProvider(Provider): - def __init__(self): - super().__init__() - self.value = 0 - - @provide(scope=Scope.REQUEST) - def get_int(self) -> int: - self.value += 1 - return self.value - - with make_container(MyProvider()) as container: - with container() as state: - assert state.get(int) == 1 - assert state.get(int) == 1 - with container() as state: - assert state.get(int) == 2 - assert state.get(int) == 2 - - -@pytest.mark.asyncio -async def test_cache_async(): - class MyProvider(Provider): - def __init__(self): - super().__init__() - self.value = 0 - - @provide(scope=Scope.REQUEST) - async def get_int(self) -> int: - self.value += 1 - return self.value - - async with make_async_container(MyProvider()) as container: - async with container() as state: - assert await state.get(int) == 1 - assert await state.get(int) == 1 - async with container() as state: - assert await state.get(int) == 2 - assert await state.get(int) == 2 diff --git a/tests/test_privider.py b/tests/test_privider.py deleted file mode 100644 index 612df89e..00000000 --- a/tests/test_privider.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest - -from dishka import Provider, Scope, alias, provide -from dishka.provider import ProviderType -from .sample_providers import ( - ClassA, - async_func_a, - async_gen_a, - async_iter_a, - sync_func_a, - sync_gen_a, - sync_iter_a, -) - - -def test_provider_init(): - class MyProvider(Provider): - a = alias(int, bool) - b = provide(lambda: False, scope=Scope.APP, dependency=bool) - - @provide(scope=Scope.REQUEST) - def foo(self, x: bool) -> str: - return f"{x}" - - provider = MyProvider() - assert len(provider.dependencies) == 2 - assert len(provider.aliases) == 1 - - -@pytest.mark.parametrize( - "factory, provider_type, is_to_bound", [ - (ClassA, ProviderType.FACTORY, False), - (sync_func_a, ProviderType.FACTORY, True), - (sync_iter_a, ProviderType.GENERATOR, True), - (sync_gen_a, ProviderType.GENERATOR, True), - (async_func_a, ProviderType.ASYNC_FACTORY, True), - (async_iter_a, ProviderType.ASYNC_GENERATOR, True), - (async_gen_a, ProviderType.ASYNC_GENERATOR, True), - ], -) -def test_parse_provider(factory, provider_type, is_to_bound): - dep_provider = provide(factory, scope=Scope.REQUEST) - assert dep_provider.result_type == ClassA - assert dep_provider.dependencies == [int] - assert dep_provider.is_to_bound == is_to_bound - assert dep_provider.scope == Scope.REQUEST - assert dep_provider.callable == factory - assert dep_provider.type == provider_type diff --git a/tests/test_provider.py b/tests/test_provider.py new file mode 100644 index 00000000..578fe024 --- /dev/null +++ b/tests/test_provider.py @@ -0,0 +1,47 @@ +import pytest + +from dishka import Provider, Scope, alias, provide +from dishka.dependency_source import FactoryType +from .sample_providers import ( + ClassA, + async_func_a, + async_gen_a, + async_iter_a, + sync_func_a, + sync_gen_a, + sync_iter_a, +) + + +def test_provider_init(): + class MyProvider(Provider): + a = alias(source=int, provides=bool) + + @provide(scope=Scope.REQUEST) + def foo(self, x: bool) -> str: + return f"{x}" + + provider = MyProvider() + assert len(provider.factories) == 1 + assert len(provider.aliases) == 1 + + +@pytest.mark.parametrize( + "source, provider_type, is_to_bound", [ + (ClassA, FactoryType.FACTORY, False), + (sync_func_a, FactoryType.FACTORY, True), + (sync_iter_a, FactoryType.GENERATOR, True), + (sync_gen_a, FactoryType.GENERATOR, True), + (async_func_a, FactoryType.ASYNC_FACTORY, True), + (async_iter_a, FactoryType.ASYNC_GENERATOR, True), + (async_gen_a, FactoryType.ASYNC_GENERATOR, True), + ], +) +def test_parse_factory(source, provider_type, is_to_bound): + factory = provide(source, scope=Scope.REQUEST) + assert factory.provides == ClassA + assert factory.dependencies == [int] + assert factory.is_to_bound == is_to_bound + assert factory.scope == Scope.REQUEST + assert factory.source == source + assert factory.type == provider_type