diff --git a/.mypy.ini b/.mypy.ini index ac2f644c..5ee6bf3a 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,6 +1,21 @@ [mypy] files = src/dishka -exclude = ^src/dishka/(_adaptix|integrations)/ +exclude = (?x)( + ^src/dishka/_adaptix/ + |^src/dishka/integrations/( + aiohttp + |aiogram + |celery + |click + |fastapi + |faststream + |grpcio + |litestar + |sanic + |starlette + |taskiq + |telebot + ).py) strict = true strict_bytes = true diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 680d1cab..fdb07349 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -7,6 +7,7 @@ ) from dataclasses import dataclass from enum import Enum +from functools import partial from inspect import ( Parameter, Signature, @@ -54,10 +55,51 @@ ) +class ParameterDependencyResolver: + def __init__( + self, + params: Sequence[Parameter], + dependencies: dict[str, DependencyKey], + ): + self._selected_named_deps: list[tuple[str, DependencyKey]] = [] + self._named_deps_predicates = [] + named_params = {param.name: param for param in params} + named_idxs = {param.name: i for i, param in enumerate(params)} + for name, dep in dependencies.items(): + match named_params[name].kind: + case Parameter.POSITIONAL_OR_KEYWORD: + pred = partial(_has_pos_or_kw, named_idxs[name], name) + case Parameter.KEYWORD_ONLY: + pred = partial(_has_kw_only, name) + case kind: + raise NotImplementedError( + f"Unsupported parameter kind: {kind}", + ) + self._named_deps_predicates.append((name, dep, pred)) + + def bind(self, *args: Any, **kwargs: Any) -> None: + self._selected_named_deps = [ + (name, dep) + for name, dep, has_param in self._named_deps_predicates + if not has_param(*args, **kwargs) + ] + + def items(self) -> Iterator[tuple[str, DependencyKey]]: + return iter(self._selected_named_deps) + + +def _has_pos_or_kw(i: int, name: str, *args: Any, **kwargs: Any) -> bool: + return i < len(args) or name in kwargs + + +def _has_kw_only(name: str, *args: Any, **kwargs: Any) -> bool: + return name in kwargs + + def _get_auto_injected_async_gen_scoped( container_getter: ContainerGetter[AsyncContainer], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: @@ -73,6 +115,8 @@ async def auto_injected_generator( ) container = container_getter(args, kwargs) async with container(additional_context) as container: + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: await container.get( dep.type_hint, @@ -89,7 +133,7 @@ async def auto_injected_generator( def _get_auto_injected_async_gen( container_getter: ContainerGetter[AsyncContainer], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: @@ -104,6 +148,8 @@ async def auto_injected_generator( kwargs.pop(param.name) container = container_getter(args, kwargs) + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: await container.get( dep.type_hint, @@ -120,7 +166,7 @@ async def auto_injected_generator( def _get_auto_injected_async_func_scoped( container_getter: ContainerGetter[AsyncContainer], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Awaitable[T]]: @@ -133,6 +179,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: {} if provide_context is None else provide_context(args, kwargs) ) async with container(additional_context) as container: + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: await container.get( dep.type_hint, @@ -148,7 +196,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_async_func( container_getter: ContainerGetter[AsyncContainer], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Awaitable[T]]: @@ -159,6 +207,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) for param in additional_params: kwargs.pop(param.name) + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: await container.get( dep.type_hint, @@ -174,7 +224,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_gen_scoped( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, Iterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Iterator[T]]: @@ -190,6 +240,8 @@ def auto_injected_generator( ) container = container_getter(args, kwargs) with container(additional_context) as container: + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: container.get(dep.type_hint, component=dep.component) for name, dep in dependencies.items() @@ -202,7 +254,7 @@ def auto_injected_generator( def _get_auto_injected_sync_gen( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, Iterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Iterator[T]]: @@ -216,6 +268,9 @@ def auto_injected_generator( container = container_getter(args, kwargs) for param in additional_params: kwargs.pop(param.name) + + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: container.get(dep.type_hint, component=dep.component) for name, dep in dependencies.items() @@ -228,7 +283,7 @@ def auto_injected_generator( def _get_auto_injected_sync_func_scoped( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, T], provide_context: ProvideContext | None = None, ) -> Callable[P, T]: @@ -241,6 +296,8 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: ) container = container_getter(args, kwargs) with container(additional_context) as container: + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: container.get(dep.type_hint, component=dep.component) for name, dep in dependencies.items() @@ -253,7 +310,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_func( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, T], provide_context: ProvideContext | None = None, ) -> Callable[P, T]: @@ -265,6 +322,8 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: for param in additional_params: kwargs.pop(param.name) + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: container.get(dep.type_hint, component=dep.component) for name, dep in dependencies.items() @@ -277,8 +336,8 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_container_async_gen_scoped( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], - func: Callable[P, Iterator[T]], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, + func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: async def auto_injected_generator( @@ -293,6 +352,8 @@ async def auto_injected_generator( ) container = container_getter(args, kwargs) with container(additional_context) as container: + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: container.get(dep.type_hint, component=dep.component) for name, dep in dependencies.items() @@ -306,8 +367,8 @@ async def auto_injected_generator( def _get_auto_injected_sync_container_async_gen( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], - func: Callable[P, Iterator[T]], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, + func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: if provide_context is not None: @@ -320,6 +381,9 @@ async def auto_injected_generator( container = container_getter(args, kwargs) for param in additional_params: kwargs.pop(param.name) + + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: container.get(dep.type_hint, component=dep.component) for name, dep in dependencies.items() @@ -333,10 +397,10 @@ async def auto_injected_generator( def _get_auto_injected_sync_container_async_func_scoped( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], - func: Callable[P, T], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, + func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, -) -> Callable[P, T]: +) -> Callable[P, Awaitable[T]]: async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: for param in additional_params: kwargs.pop(param.name) @@ -346,6 +410,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: ) container = container_getter(args, kwargs) with container(additional_context) as container: + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: container.get(dep.type_hint, component=dep.component) for name, dep in dependencies.items() @@ -358,10 +424,10 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_container_async_func( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], - func: Callable[P, T], + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, + func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, -) -> Callable[P, T]: +) -> Callable[P, Awaitable[T]]: if provide_context is not None: raise ImproperProvideContextUsageError @@ -370,6 +436,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: for param in additional_params: kwargs.pop(param.name) + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) solved = { name: container.get(dep.type_hint, component=dep.component) for name, dep in dependencies.items() @@ -464,7 +532,7 @@ def __post_init__(self) -> None: } -def get_func_type(func: Callable) -> FunctionType: +def get_func_type(func: Callable[P, T]) -> FunctionType: if isasyncgenfunction(func): return FunctionType.ASYNC_GENERATOR elif isgeneratorfunction(func): @@ -573,24 +641,6 @@ def wrap_injection( for param in new_params ] - auto_injected_func: Callable[P, T | Awaitable[T]] - if additional_params: - new_params = _add_params(new_params, additional_params) - for param in additional_params: - new_annotations[param.name] = param.annotation - - if is_async: - func = cast(Callable[P, Awaitable[T]], func) - container_getter = cast( - ContainerGetter[AsyncContainer], - container_getter, - ) - else: - container_getter = cast( - ContainerGetter[Container], - container_getter, - ) - injected_func_type = InjectedFuncType( is_async_container=is_async, manage_scope=manage_scope, @@ -598,22 +648,31 @@ def wrap_injection( ) get_auto_injected_func = _GET_AUTO_INJECTED_FUNC_DICT[injected_func_type] - auto_injected_func = get_auto_injected_func( + auto_injected_func = get_auto_injected_func( # type: ignore[operator] func=func, provide_context=provide_context, - dependencies=dependencies, + dependencies=( + dependencies + if remove_depends + else ParameterDependencyResolver(new_params, dependencies) + ), additional_params=additional_params, container_getter=container_getter, ) + if additional_params: + new_params = _add_params(new_params, additional_params) + for param in additional_params: + new_annotations[param.name] = param.annotation + auto_injected_func.__dishka_orig_func__ = func - auto_injected_func.__dishka_injected__ = True # type: ignore[attr-defined] + auto_injected_func.__dishka_injected__ = True auto_injected_func.__name__ = func.__name__ auto_injected_func.__qualname__ = func.__qualname__ auto_injected_func.__doc__ = func.__doc__ auto_injected_func.__module__ = func.__module__ auto_injected_func.__annotations__ = new_annotations - auto_injected_func.__signature__ = Signature( # type: ignore[attr-defined] + auto_injected_func.__signature__ = Signature( parameters=new_params, return_annotation=func_signature.return_annotation, ) @@ -627,7 +686,7 @@ def is_dishka_injected(func: Callable[..., Any]) -> bool: def _add_params( params: Sequence[Parameter], additional_params: Sequence[Parameter], -): +) -> list[Parameter]: params_kind_dict: dict[_ParameterKind, list[Parameter]] = {} for param in params: diff --git a/tests/integrations/base/test_parameter_dependency_resolver.py b/tests/integrations/base/test_parameter_dependency_resolver.py new file mode 100644 index 00000000..bf66861a --- /dev/null +++ b/tests/integrations/base/test_parameter_dependency_resolver.py @@ -0,0 +1,92 @@ +from inspect import signature +from unittest.mock import Mock + +import pytest + +from dishka import FromDishka +from dishka.integrations.base import ParameterDependencyResolver +from tests.integrations.common import AppMock + +Dep1 = FromDishka[Mock] +Dep2 = FromDishka[AppMock] + + +def pos_or_kw(i: int, d1: Dep1, d2: Dep2, j: int = 0): ... + + +def kw_only(i: int, *, d1: Dep1, d2: Dep2, j: int = 0): ... + + +def mixed(i: int, d1: Dep1, *, d2: Dep2, j: int = 0): ... + + +def pos_only(i: int, d1: Dep1, d2: Dep2, /, j: int = 0): ... + + +def pos_only_d1(i: int, d1: Dep1, /, d2: Dep2, j: int = 0): ... + + +def var_args(i: int, d1: Dep1, *d2: Dep2, j: int = 0): ... + + +def var_kwargs(i: int, d1: Dep1, j: int = 0, **d2: Dep2): ... + + +def var_args_kwargs(i: int, *d1: Dep1, j: int = 0, **d2: Dep2): ... + + +def get_injected_names_factory(func): + params = list(signature(func).parameters.values()) + deps = {"d1": Mock(), "d2": AppMock(Mock())} + resolver = ParameterDependencyResolver(params, deps) + + def get_injected_names(*args, **kw): + resolver.bind(*args, **kw) + return [name for name, _ in resolver.items()] + + return get_injected_names + + +@pytest.mark.parametrize("func", [pos_or_kw, kw_only, mixed]) +def test_dont_pass_dependencies(func): + get_injected_names = get_injected_names_factory(func) + # Both dependencies injected + assert get_injected_names(1) == ["d1", "d2"] + assert get_injected_names(2, j=9) == ["d1", "d2"] + + +@pytest.mark.parametrize("func", [pos_or_kw, kw_only, mixed]) +def test_pass_dependencies_by_name(func): + get_injected_names = get_injected_names_factory(func) + # d1 passed by name, d2 injected + assert get_injected_names(1, d1=Mock()) == ["d2"] + assert get_injected_names(2, d1=Mock(), j=9) == ["d2"] + # d2 passed by name, d1 injected + assert get_injected_names(3, d2=Mock()) == ["d1"] + assert get_injected_names(4, d2=Mock(), j=9) == ["d1"] + # Both dependencies passed by name, no injection + assert get_injected_names(1, d1=Mock(), d2=Mock()) == [] + assert get_injected_names(2, d1=Mock(), d2=Mock(), j=9) == [] + + +@pytest.mark.parametrize("func", [pos_or_kw, mixed]) +def test_pass_dependencies_by_position(func): + get_injected_names = get_injected_names_factory(func) + # d1 passed positionally, d2 injected + assert get_injected_names(1, Mock()) == ["d2"] + assert get_injected_names(2, Mock(), j=9) == ["d2"] + # d1 passed positionally, d2 passed by name, no injection + assert get_injected_names(2, Mock(), d2=Mock()) == [] + assert get_injected_names(3, Mock(), d2=Mock(), j=9) == [] + if func is pos_or_kw: + # d1 and d2 passed positionally, no injection + assert get_injected_names(3, Mock(), Mock()) == [] + assert get_injected_names(3, Mock(), Mock(), j=9) == [] + + +@pytest.mark.parametrize( + "func", [pos_only, pos_only_d1, var_args, var_kwargs, var_args_kwargs], +) +def test_not_implemented_parameter_kinds(func): + with pytest.raises(NotImplementedError): + get_injected_names_factory(func) diff --git a/tests/integrations/base/test_wrap_injection_remove_depends.py b/tests/integrations/base/test_wrap_injection_remove_depends.py new file mode 100644 index 00000000..ecd81944 --- /dev/null +++ b/tests/integrations/base/test_wrap_injection_remove_depends.py @@ -0,0 +1,220 @@ +import asyncio +from collections.abc import Iterable +from inspect import isasyncgen, iscoroutine, isgenerator +from unittest.mock import Mock + +import pytest + +from dishka import FromDishka, make_async_container, make_container +from dishka.integrations.base import wrap_injection +from tests.integrations.common import AppMock + + +def raises_multiple_values(obj): + with pytest.raises(TypeError, match="multiple values for"): # noqa: PT012 + if isgenerator(obj): + list(obj) + elif callable(obj): + obj() + else: + pytest.fail("Object is neither a generator nor callable") + + +async def raises_multiple_values_async(obj): + with pytest.raises(TypeError, match="multiple values for"): # noqa: PT012 + if isasyncgen(obj): + async for _ in obj: + pass + elif iscoroutine(obj): + await obj + else: + pytest.fail("Object is neither a generator nor callable") + + +def sync_func(i: int, dep: FromDishka[AppMock], j: int = 0): + return dep(i, j) + + +def sync_gen(data: Iterable[int], dep: FromDishka[AppMock], j: int = 0): + for i in data: + yield dep(i, j) + + +async def async_func(i: int, dep: FromDishka[AppMock], j: int = 0): + await asyncio.sleep(0) + return dep(i, j) + + +async def async_gen( + data: Iterable[int], + dep: FromDishka[AppMock], + j: int = 0, +): + for i in data: + await asyncio.sleep(0) + yield dep(i, j) + + +@pytest.mark.parametrize("remove_depends", [True, False]) +def test_sync_func(remove_depends, app_provider): + container = make_container(app_provider) + wrapped_func = wrap_injection( + func=sync_func, + container_getter=lambda *_: container, + remove_depends=remove_depends, + is_async=False, + ) + + wrapped_func(1) + app_provider.app_mock.assert_called_with(1, 0) + wrapped_func(2, j=3) + app_provider.app_mock.assert_called_with(2, 3) + + app_provider.app_mock.reset_mock() + new_dep = AppMock(Mock()) + if remove_depends: + raises_multiple_values(lambda: wrapped_func(1, new_dep)) + raises_multiple_values(lambda: wrapped_func(2, dep=new_dep)) + raises_multiple_values(lambda: wrapped_func(3, new_dep, 9)) + raises_multiple_values(lambda: wrapped_func(4, new_dep, j=9)) + raises_multiple_values(lambda: wrapped_func(5, dep=new_dep, j=9)) + else: + wrapped_func(1, new_dep) + new_dep.assert_called_with(1, 0) + wrapped_func(2, dep=new_dep) + new_dep.assert_called_with(2, 0) + wrapped_func(3, new_dep, 9) + new_dep.assert_called_with(3, 9) + wrapped_func(4, new_dep, j=9) + new_dep.assert_called_with(4, 9) + wrapped_func(5, dep=new_dep, j=9) + new_dep.assert_called_with(5, 9) + app_provider.app_mock.assert_not_called() + + +@pytest.mark.parametrize("remove_depends", [True, False]) +def test_sync_gen(remove_depends, app_provider): + container = make_container(app_provider) + wrapped_func = wrap_injection( + func=sync_gen, + container_getter=lambda *_: container, + remove_depends=remove_depends, + is_async=False, + ) + + list(wrapped_func([1])) + app_provider.app_mock.assert_called_with(1, 0) + list(wrapped_func([2], j=3)) + app_provider.app_mock.assert_called_with(2, 3) + + app_provider.app_mock.reset_mock() + new_dep = AppMock(Mock()) + if remove_depends: + raises_multiple_values(wrapped_func([1], new_dep)) + raises_multiple_values(wrapped_func([2], dep=new_dep)) + raises_multiple_values(wrapped_func([3], new_dep, 9)) + raises_multiple_values(wrapped_func([4], new_dep, j=9)) + raises_multiple_values(wrapped_func([5], dep=new_dep, j=9)) + else: + list(wrapped_func([1], new_dep)) + new_dep.assert_called_with(1, 0) + list(wrapped_func([2], dep=new_dep)) + new_dep.assert_called_with(2, 0) + list(wrapped_func([3], new_dep, 9)) + new_dep.assert_called_with(3, 9) + list(wrapped_func([4], new_dep, j=9)) + new_dep.assert_called_with(4, 9) + list(wrapped_func([5], dep=new_dep, j=9)) + new_dep.assert_called_with(5, 9) + app_provider.app_mock.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_container", [True, False]) +@pytest.mark.parametrize("remove_depends", [True, False]) +async def test_async_func(async_container, remove_depends, app_provider): + if async_container: + container = make_async_container(app_provider) + else: + container = make_container(app_provider) + wrapped_func = wrap_injection( + func=async_func, + container_getter=lambda *_: container, + remove_depends=remove_depends, + is_async=async_container, + ) + + await wrapped_func(1) + app_provider.app_mock.assert_called_with(1, 0) + await wrapped_func(2, j=3) + app_provider.app_mock.assert_called_with(2, 3) + + app_provider.app_mock.reset_mock() + new_dep = AppMock(Mock()) + if remove_depends: + await raises_multiple_values_async(wrapped_func(1, new_dep)) + await raises_multiple_values_async(wrapped_func(2, dep=new_dep)) + await raises_multiple_values_async(wrapped_func(3, new_dep, 9)) + await raises_multiple_values_async(wrapped_func(4, new_dep, j=9)) + await raises_multiple_values_async(wrapped_func(5, dep=new_dep, j=9)) + else: + await wrapped_func(1, new_dep) + new_dep.assert_called_with(1, 0) + await wrapped_func(2, dep=new_dep) + new_dep.assert_called_with(2, 0) + await wrapped_func(3, new_dep, 9) + new_dep.assert_called_with(3, 9) + await wrapped_func(4, new_dep, j=9) + new_dep.assert_called_with(4, 9) + await wrapped_func(5, dep=new_dep, j=9) + new_dep.assert_called_with(5, 9) + app_provider.app_mock.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_container", [True, False]) +@pytest.mark.parametrize("remove_depends", [True, False]) +async def test_async_gen(async_container, remove_depends, app_provider): + if async_container: + container = make_async_container(app_provider) + else: + container = make_container(app_provider) + wrapped_func = wrap_injection( + func=async_gen, + container_getter=lambda *_: container, + remove_depends=remove_depends, + is_async=async_container, + ) + + async for _ in wrapped_func([1]): + pass + app_provider.app_mock.assert_called_with(1, 0) + async for _ in wrapped_func([2], j=3): + pass + app_provider.app_mock.assert_called_with(2, 3) + + app_provider.app_mock.reset_mock() + new_dep = AppMock(Mock()) + if remove_depends: + await raises_multiple_values_async(wrapped_func([1], new_dep)) + await raises_multiple_values_async(wrapped_func([2], dep=new_dep)) + await raises_multiple_values_async(wrapped_func([3], new_dep, 9)) + await raises_multiple_values_async(wrapped_func([4], new_dep, j=9)) + await raises_multiple_values_async(wrapped_func([5], dep=new_dep, j=9)) + else: + async for _ in wrapped_func([1], new_dep): + pass + new_dep.assert_called_with(1, 0) + async for _ in wrapped_func([2], dep=new_dep): + pass + new_dep.assert_called_with(2, 0) + async for _ in wrapped_func([3], new_dep, 9): + pass + new_dep.assert_called_with(3, 9) + async for _ in wrapped_func([4], new_dep, j=9): + pass + new_dep.assert_called_with(4, 9) + async for _ in wrapped_func([5], dep=new_dep, j=9): + pass + new_dep.assert_called_with(5, 9) + app_provider.app_mock.assert_not_called()