Skip to content
17 changes: 16 additions & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
[mypy]
files = src/dishka
exclude = ^src/dishka/(_adaptix|integrations)/
exclude = (?x)(
^src/dishka/_adaptix/
|^src/dishka/integrations/(
aiohttp
|aiogram
|celery
|click
|fastapi
|faststream
|grpcio
|litestar
|sanic
|starlette
|taskiq
|telebot
).py)

strict = true
strict_bytes = true
Expand Down
143 changes: 101 additions & 42 deletions src/dishka/integrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
from dataclasses import dataclass
from enum import Enum
from functools import partial
from inspect import (
Parameter,
Signature,
Expand Down Expand Up @@ -54,10 +55,51 @@
)


class ParameterDependencyResolver:
def __init__(
self,
params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
):
self._selected_named_deps: list[tuple[str, DependencyKey]] = []
self._named_deps_predicates = []
named_params = {param.name: param for param in params}
named_idxs = {param.name: i for i, param in enumerate(params)}
for name, dep in dependencies.items():
match named_params[name].kind:
case Parameter.POSITIONAL_OR_KEYWORD:
pred = partial(_has_pos_or_kw, named_idxs[name], name)
case Parameter.KEYWORD_ONLY:
pred = partial(_has_kw_only, name)
case kind:
raise NotImplementedError(
f"Unsupported parameter kind: {kind}",
)
self._named_deps_predicates.append((name, dep, pred))

def bind(self, *args: Any, **kwargs: Any) -> None:
self._selected_named_deps = [
(name, dep)
for name, dep, has_param in self._named_deps_predicates
if not has_param(*args, **kwargs)
]

def items(self) -> Iterator[tuple[str, DependencyKey]]:
return iter(self._selected_named_deps)


def _has_pos_or_kw(i: int, name: str, *args: Any, **kwargs: Any) -> bool:
return i < len(args) or name in kwargs


def _has_kw_only(name: str, *args: Any, **kwargs: Any) -> bool:
return name in kwargs


def _get_auto_injected_async_gen_scoped(
container_getter: ContainerGetter[AsyncContainer],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, AsyncIterator[T]],
provide_context: ProvideContext | None = None,
) -> Callable[P, AsyncIterator[T]]:
Expand All @@ -73,6 +115,8 @@ async def auto_injected_generator(
)
container = container_getter(args, kwargs)
async with container(additional_context) as container:
if isinstance(dependencies, ParameterDependencyResolver):
Copy link
Member

@Tishka17 Tishka17 Jun 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's have 2 spearate function arguments. One is used to bind, another - to iterate over deps. So we can simplify check with if binder is not None

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify, you mean replace the dependencies: dict[str, DependencyKey] | ParameterDependencyResolver with two parameters dependencies: dict[str, DependencyKey] and resolver: ParameterDependencyResolver in all _get_auto_injected_* functions, and then use only one of the two?

if resolver is None:
   # iterate over dependencies
else:
   # iterate over resolver - dependencies are ignored 

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I mean this. It might be not clear in terms of encapsulation but easier to check and rewrite with code generation

dependencies.bind(*args, **kwargs)
solved = {
name: await container.get(
dep.type_hint,
Expand All @@ -89,7 +133,7 @@ async def auto_injected_generator(
def _get_auto_injected_async_gen(
container_getter: ContainerGetter[AsyncContainer],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, AsyncIterator[T]],
provide_context: ProvideContext | None = None,
) -> Callable[P, AsyncIterator[T]]:
Expand All @@ -104,6 +148,8 @@ async def auto_injected_generator(
kwargs.pop(param.name)

container = container_getter(args, kwargs)
if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: await container.get(
dep.type_hint,
Expand All @@ -120,7 +166,7 @@ async def auto_injected_generator(
def _get_auto_injected_async_func_scoped(
container_getter: ContainerGetter[AsyncContainer],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, Awaitable[T]],
provide_context: ProvideContext | None = None,
) -> Callable[P, Awaitable[T]]:
Expand All @@ -133,6 +179,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
{} if provide_context is None else provide_context(args, kwargs)
)
async with container(additional_context) as container:
if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: await container.get(
dep.type_hint,
Expand All @@ -148,7 +196,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
def _get_auto_injected_async_func(
container_getter: ContainerGetter[AsyncContainer],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, Awaitable[T]],
provide_context: ProvideContext | None = None,
) -> Callable[P, Awaitable[T]]:
Expand All @@ -159,6 +207,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
container = container_getter(args, kwargs)
for param in additional_params:
kwargs.pop(param.name)
if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: await container.get(
dep.type_hint,
Expand All @@ -174,7 +224,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
def _get_auto_injected_sync_gen_scoped(
container_getter: ContainerGetter[Container],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, Iterator[T]],
provide_context: ProvideContext | None = None,
) -> Callable[P, Iterator[T]]:
Expand All @@ -190,6 +240,8 @@ def auto_injected_generator(
)
container = container_getter(args, kwargs)
with container(additional_context) as container:
if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: container.get(dep.type_hint, component=dep.component)
for name, dep in dependencies.items()
Expand All @@ -202,7 +254,7 @@ def auto_injected_generator(
def _get_auto_injected_sync_gen(
container_getter: ContainerGetter[Container],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, Iterator[T]],
provide_context: ProvideContext | None = None,
) -> Callable[P, Iterator[T]]:
Expand All @@ -216,6 +268,9 @@ def auto_injected_generator(
container = container_getter(args, kwargs)
for param in additional_params:
kwargs.pop(param.name)

if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: container.get(dep.type_hint, component=dep.component)
for name, dep in dependencies.items()
Expand All @@ -228,7 +283,7 @@ def auto_injected_generator(
def _get_auto_injected_sync_func_scoped(
container_getter: ContainerGetter[Container],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, T],
provide_context: ProvideContext | None = None,
) -> Callable[P, T]:
Expand All @@ -241,6 +296,8 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
)
container = container_getter(args, kwargs)
with container(additional_context) as container:
if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: container.get(dep.type_hint, component=dep.component)
for name, dep in dependencies.items()
Expand All @@ -253,7 +310,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
def _get_auto_injected_sync_func(
container_getter: ContainerGetter[Container],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, T],
provide_context: ProvideContext | None = None,
) -> Callable[P, T]:
Expand All @@ -265,6 +322,8 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
for param in additional_params:
kwargs.pop(param.name)

if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: container.get(dep.type_hint, component=dep.component)
for name, dep in dependencies.items()
Expand All @@ -277,8 +336,8 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
def _get_auto_injected_sync_container_async_gen_scoped(
container_getter: ContainerGetter[Container],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
func: Callable[P, Iterator[T]],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, AsyncIterator[T]],
provide_context: ProvideContext | None = None,
) -> Callable[P, AsyncIterator[T]]:
async def auto_injected_generator(
Expand All @@ -293,6 +352,8 @@ async def auto_injected_generator(
)
container = container_getter(args, kwargs)
with container(additional_context) as container:
if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: container.get(dep.type_hint, component=dep.component)
for name, dep in dependencies.items()
Expand All @@ -306,8 +367,8 @@ async def auto_injected_generator(
def _get_auto_injected_sync_container_async_gen(
container_getter: ContainerGetter[Container],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
func: Callable[P, Iterator[T]],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, AsyncIterator[T]],
provide_context: ProvideContext | None = None,
) -> Callable[P, AsyncIterator[T]]:
if provide_context is not None:
Expand All @@ -320,6 +381,9 @@ async def auto_injected_generator(
container = container_getter(args, kwargs)
for param in additional_params:
kwargs.pop(param.name)

if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: container.get(dep.type_hint, component=dep.component)
for name, dep in dependencies.items()
Expand All @@ -333,10 +397,10 @@ async def auto_injected_generator(
def _get_auto_injected_sync_container_async_func_scoped(
container_getter: ContainerGetter[Container],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
func: Callable[P, T],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, Awaitable[T]],
provide_context: ProvideContext | None = None,
) -> Callable[P, T]:
) -> Callable[P, Awaitable[T]]:
async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
for param in additional_params:
kwargs.pop(param.name)
Expand All @@ -346,6 +410,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
)
container = container_getter(args, kwargs)
with container(additional_context) as container:
if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: container.get(dep.type_hint, component=dep.component)
for name, dep in dependencies.items()
Expand All @@ -358,10 +424,10 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
def _get_auto_injected_sync_container_async_func(
container_getter: ContainerGetter[Container],
additional_params: Sequence[Parameter],
dependencies: dict[str, DependencyKey],
func: Callable[P, T],
dependencies: dict[str, DependencyKey] | ParameterDependencyResolver,
func: Callable[P, Awaitable[T]],
provide_context: ProvideContext | None = None,
) -> Callable[P, T]:
) -> Callable[P, Awaitable[T]]:
if provide_context is not None:
raise ImproperProvideContextUsageError

Expand All @@ -370,6 +436,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
for param in additional_params:
kwargs.pop(param.name)

if isinstance(dependencies, ParameterDependencyResolver):
dependencies.bind(*args, **kwargs)
solved = {
name: container.get(dep.type_hint, component=dep.component)
for name, dep in dependencies.items()
Expand Down Expand Up @@ -464,7 +532,7 @@ def __post_init__(self) -> None:
}


def get_func_type(func: Callable) -> FunctionType:
def get_func_type(func: Callable[P, T]) -> FunctionType:
if isasyncgenfunction(func):
return FunctionType.ASYNC_GENERATOR
elif isgeneratorfunction(func):
Expand Down Expand Up @@ -573,47 +641,38 @@ def wrap_injection(
for param in new_params
]

auto_injected_func: Callable[P, T | Awaitable[T]]
if additional_params:
new_params = _add_params(new_params, additional_params)
for param in additional_params:
new_annotations[param.name] = param.annotation

if is_async:
func = cast(Callable[P, Awaitable[T]], func)
container_getter = cast(
ContainerGetter[AsyncContainer],
container_getter,
)
else:
container_getter = cast(
ContainerGetter[Container],
container_getter,
)

injected_func_type = InjectedFuncType(
is_async_container=is_async,
manage_scope=manage_scope,
func_type=get_func_type(func),
)
get_auto_injected_func = _GET_AUTO_INJECTED_FUNC_DICT[injected_func_type]

auto_injected_func = get_auto_injected_func(
auto_injected_func = get_auto_injected_func( # type: ignore[operator]
func=func,
provide_context=provide_context,
dependencies=dependencies,
dependencies=(
dependencies
if remove_depends
else ParameterDependencyResolver(new_params, dependencies)
),
additional_params=additional_params,
container_getter=container_getter,
)

if additional_params:
new_params = _add_params(new_params, additional_params)
for param in additional_params:
new_annotations[param.name] = param.annotation

auto_injected_func.__dishka_orig_func__ = func
auto_injected_func.__dishka_injected__ = True # type: ignore[attr-defined]
auto_injected_func.__dishka_injected__ = True
auto_injected_func.__name__ = func.__name__
auto_injected_func.__qualname__ = func.__qualname__
auto_injected_func.__doc__ = func.__doc__
auto_injected_func.__module__ = func.__module__
auto_injected_func.__annotations__ = new_annotations
auto_injected_func.__signature__ = Signature( # type: ignore[attr-defined]
auto_injected_func.__signature__ = Signature(
parameters=new_params,
return_annotation=func_signature.return_annotation,
)
Expand All @@ -627,7 +686,7 @@ def is_dishka_injected(func: Callable[..., Any]) -> bool:
def _add_params(
params: Sequence[Parameter],
additional_params: Sequence[Parameter],
):
) -> list[Parameter]:
params_kind_dict: dict[_ParameterKind, list[Parameter]] = {}

for param in params:
Expand Down
Loading