Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions src/dishka/integrations/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from collections.abc import Awaitable, Callable, Sequence
from inspect import (
Parameter,
Expand Down Expand Up @@ -96,9 +97,15 @@ def wrap_injection(
func_signature = signature(func)

dependencies = {}
for name, param in func_signature.parameters.items():
hint = hints.get(name, Any)
dep = parse_dependency(param, hint)
for index, (name, param) in enumerate(func_signature.parameters.items()):
if name == "self" and index == 0:
# If it's a method in a class, by the time this is run the class
# hasn't been created yet, and inspection would fail. So,
# postpone it.
dep = DependencyKey(func, DEFAULT_COMPONENT)
else:
hint = hints.get(name, Any)
dep = parse_dependency(param, hint)
if dep is None:
continue
dependencies[name] = dep
Expand Down Expand Up @@ -185,6 +192,11 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
container = container_getter(args, kwargs)
for param in additional_params:
kwargs.pop(param.name)

if (dep := dependencies.get("self")) and dep.type_hint == func:
klass = inspect._findclass(dep.type_hint) # noqa: SLF001
dependencies["self"] = DependencyKey(klass, dep.component)

solved = {
name: await container.get(
dep.type_hint, component=dep.component,
Expand All @@ -198,6 +210,11 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T:
container = container_getter(args, kwargs)
for param in additional_params:
kwargs.pop(param.name)

if (dep := dependencies.get("self")) and dep.type_hint == func:
klass = inspect._findclass(dep.type_hint) # noqa: SLF001
dependencies["self"] = DependencyKey(klass, dep.component)

solved = {
name: await container.get(
dep.type_hint, component=dep.component,
Expand Down
41 changes: 40 additions & 1 deletion tests/integrations/fastapi/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from asgi_lifespan import LifespanManager
from fastapi.testclient import TestClient

from dishka import make_async_container
from dishka import Scope, make_async_container, provide
from dishka.integrations.fastapi import (
DishkaRoute,
FromDishka,
Expand All @@ -18,6 +18,7 @@
APP_DEP_VALUE,
REQUEST_DEP_VALUE,
AppDep,
AppMock,
AppProvider,
RequestDep,
)
Expand Down Expand Up @@ -68,6 +69,44 @@ async def test_app_dependency(app_provider: AppProvider, app_factory):
app_provider.app_released.assert_called()


class Wrapper:
def __init__(
self,
a: AppDep,
app_mock: AppMock,
):
self.a = a
self.app_mock = app_mock

async def get_with_app(
self,
a: FromDishka[AppDep],
app_mock: FromDishka[AppMock],
) -> None:
assert self.a == a
assert self.app_mock == app_mock
app_mock(a)


class LocalProvider(AppProvider):
scope = Scope.APP

wrapper = provide(Wrapper)


@pytest.mark.parametrize("app_factory", [
dishka_app, dishka_auto_app,
])
@pytest.mark.asyncio
async def test_app_dependency_class(app_factory):
app_provider = LocalProvider()
async with app_factory(Wrapper.get_with_app, app_provider) as client:
client.get("/")
app_provider.app_mock.assert_called_with(APP_DEP_VALUE)
app_provider.app_released.assert_not_called()
app_provider.app_released.assert_called()


async def get_with_request(
a: FromDishka[RequestDep],
mock: FromDishka[Mock],
Expand Down