Skip to content

Commit 50aef14

Browse files
MuriloScarpaSitonioMuriloScarpaSitonio
andauthored
Handle functools.partial in wireup.ioc.util.get_globals (#92)
* fix wireup.ioc.util.get_globals * PR remarks * PR remarks * fix lint * fix ruff --------- Co-authored-by: MuriloScarpaSitonio <[email protected]>
1 parent 7574d25 commit 50aef14

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

test/unit/test_util.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1+
import functools
12
import inspect
23
import unittest
4+
from typing import Callable
35

46
import pytest
57
from typing_extensions import Annotated
68
from wireup import Inject
79
from wireup.errors import WireupError
8-
from wireup.ioc.types import AnnotatedParameter, InjectableType, ParameterWrapper, ServiceQualifier
10+
from wireup.ioc.types import (
11+
AnnotatedParameter,
12+
InjectableType,
13+
ParameterWrapper,
14+
ServiceQualifier,
15+
)
916
from wireup.ioc.util import (
17+
get_globals,
1018
param_get_annotation,
1119
)
1220

@@ -29,10 +37,14 @@ def inner(
2937
): ...
3038

3139
params = inspect.signature(inner)
32-
self.assertEqual(param_get_annotation(params.parameters["_a"], globalns=globals()), AnnotatedParameter(str, d1))
40+
self.assertEqual(
41+
param_get_annotation(params.parameters["_a"], globalns=globals()),
42+
AnnotatedParameter(str, d1),
43+
)
3344
self.assertEqual(param_get_annotation(params.parameters["_b"], globalns=globals()), None)
3445
self.assertEqual(
35-
param_get_annotation(params.parameters["_c"], globalns=globals()), AnnotatedParameter(str, None)
46+
param_get_annotation(params.parameters["_c"], globalns=globals()),
47+
AnnotatedParameter(str, None),
3648
)
3749
self.assertEqual(
3850
param_get_annotation(params.parameters["_d"], globalns=globals()),
@@ -65,3 +77,36 @@ def inner(_a: Annotated[str, Inject(), Inject(param="foo")]): ...
6577

6678
class MyCustomClass:
6779
pass
80+
81+
82+
def _sample_function():
83+
pass
84+
85+
86+
def test_returns_globals_for_class():
87+
# GIVEN
88+
cls = MyCustomClass
89+
90+
# WHEN
91+
result = get_globals(cls)
92+
93+
# THEN
94+
assert "_sample_function" in result
95+
assert "MyCustomClass" in result
96+
97+
98+
@pytest.mark.parametrize(
99+
"partial_func",
100+
(
101+
_sample_function,
102+
functools.partial(_sample_function),
103+
functools.partial(functools.partial(functools.partial(_sample_function))),
104+
),
105+
)
106+
def test_unwraps_functools_partial(partial_func: Callable):
107+
# GIVEN
108+
# WHEN
109+
result = get_globals(partial_func)
110+
111+
# THEN
112+
assert result is _sample_function.__globals__

wireup/ioc/util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import functools
34
import importlib
45
import inspect
56
import types
@@ -82,6 +83,10 @@ def get_globals(obj: type[Any] | Callable[..., Any]) -> dict[str, Any]:
8283
if isinstance(obj, type):
8384
return importlib.import_module(obj.__module__).__dict__
8485

86+
# Unwrap nested functools.partial to get the underlying function
87+
while isinstance(obj, functools.partial):
88+
obj = obj.func
89+
8590
return obj.__globals__
8691

8792

0 commit comments

Comments
 (0)