1+ import functools
12import inspect
23import unittest
4+ from typing import Callable
35
46import pytest
57from typing_extensions import Annotated
68from wireup import Inject
79from wireup .errors import WireupError
8- from wireup .ioc .types import AnnotatedParameter , InjectableType , ParameterWrapper , ServiceQualifier
10+ from wireup .ioc .types import (
11+ AnnotatedParameter ,
12+ InjectableType ,
13+ ParameterWrapper ,
14+ ServiceQualifier ,
15+ )
916from wireup .ioc .util import (
17+ get_globals ,
1018 param_get_annotation ,
1119)
1220
@@ -29,10 +37,14 @@ def inner(
2937 ): ...
3038
3139 params = inspect .signature (inner )
32- self .assertEqual (param_get_annotation (params .parameters ["_a" ], globalns = globals ()), AnnotatedParameter (str , d1 ))
40+ self .assertEqual (
41+ param_get_annotation (params .parameters ["_a" ], globalns = globals ()),
42+ AnnotatedParameter (str , d1 ),
43+ )
3344 self .assertEqual (param_get_annotation (params .parameters ["_b" ], globalns = globals ()), None )
3445 self .assertEqual (
35- param_get_annotation (params .parameters ["_c" ], globalns = globals ()), AnnotatedParameter (str , None )
46+ param_get_annotation (params .parameters ["_c" ], globalns = globals ()),
47+ AnnotatedParameter (str , None ),
3648 )
3749 self .assertEqual (
3850 param_get_annotation (params .parameters ["_d" ], globalns = globals ()),
@@ -65,3 +77,36 @@ def inner(_a: Annotated[str, Inject(), Inject(param="foo")]): ...
6577
6678class MyCustomClass :
6779 pass
80+
81+
82+ def _sample_function ():
83+ pass
84+
85+
86+ def test_returns_globals_for_class ():
87+ # GIVEN
88+ cls = MyCustomClass
89+
90+ # WHEN
91+ result = get_globals (cls )
92+
93+ # THEN
94+ assert "_sample_function" in result
95+ assert "MyCustomClass" in result
96+
97+
98+ @pytest .mark .parametrize (
99+ "partial_func" ,
100+ (
101+ _sample_function ,
102+ functools .partial (_sample_function ),
103+ functools .partial (functools .partial (functools .partial (_sample_function ))),
104+ ),
105+ )
106+ def test_unwraps_functools_partial (partial_func : Callable ):
107+ # GIVEN
108+ # WHEN
109+ result = get_globals (partial_func )
110+
111+ # THEN
112+ assert result is _sample_function .__globals__
0 commit comments