11
11
from operator import methodcaller
12
12
from typing_extensions import Annotated
13
13
from dataclasses import field , dataclass
14
- from inspect import Parameter , Signature , isclass
14
+ from inspect import Parameter , Signature
15
15
from collections .abc import Callable , Iterable , Generator
16
16
from typing import TYPE_CHECKING , Any , TypeVar , Coroutine
17
17
from importlib .metadata import Distribution , PackageNotFoundError , distribution
@@ -89,8 +89,8 @@ def flush(self):
89
89
class Option :
90
90
stream : bool = True
91
91
scalars : bool = False
92
+ calls : tuple [methodcaller , ...] = field (default_factory = tuple )
92
93
result : methodcaller | None = None
93
- calls : tuple [methodcaller ] = field (default_factory = tuple )
94
94
95
95
96
96
@dataclass
@@ -122,12 +122,12 @@ async def __call__(self, *, _session: async_scoped_session, **params: Any) -> An
122
122
else :
123
123
result = await _session .execute (self .statement , params )
124
124
125
- for call in self .option .calls :
126
- result = call (result )
127
-
128
125
if self .option .scalars :
129
126
result = result .scalars ()
130
127
128
+ for call in self .option .calls :
129
+ result = call (result )
130
+
131
131
if call := self .option .result :
132
132
result = call (result )
133
133
@@ -140,14 +140,17 @@ def __hash__(self) -> int:
140
140
return hash ((self .statement , self .option ))
141
141
142
142
143
- def generic_issubclass (scls : Any , cls : Any ) -> Any :
144
- if cls is Any :
145
- return True
143
+ def generic_issubclass (scls : Any , cls : Any ) -> bool | list [ Any ] :
144
+ if isinstance ( cls , tuple ) :
145
+ return _map_generic_issubclass ( repeat ( scls ), cls )
146
146
147
147
if scls is Any :
148
- return cls
148
+ return [cls ]
149
+
150
+ if cls is Any :
151
+ return True
149
152
150
- if isclass ( scls ) and ( isclass ( cls ) or isinstance ( cls , tuple ) ):
153
+ with suppress ( TypeError ):
151
154
return issubclass (scls , cls )
152
155
153
156
scls_origin , scls_args = get_origin (scls ) or scls , get_args (scls )
@@ -158,15 +161,17 @@ def generic_issubclass(scls: Any, cls: Any) -> Any:
158
161
return generic_issubclass (scls_args [0 ], cls_args )
159
162
160
163
if len (cls_args ) == 2 and cls_args [1 ] is Ellipsis :
161
- return all (map (generic_issubclass , scls_args , repeat (cls_args [0 ])))
164
+ return _map_generic_issubclass (
165
+ scls_args , repeat (cls_args [0 ]), failfast = True
166
+ )
162
167
163
168
if scls_origin is Annotated :
164
169
return generic_issubclass (scls_args [0 ], cls )
165
170
if cls_origin is Annotated :
166
171
return generic_issubclass (scls , cls_args [0 ])
167
172
168
173
if origin_is_union (scls_origin ):
169
- return all ( map ( generic_issubclass , scls_args , repeat (cls )) )
174
+ return _map_generic_issubclass ( scls_args , repeat (cls ), failfast = True )
170
175
if origin_is_union (cls_origin ):
171
176
return generic_issubclass (scls , cls_args )
172
177
@@ -182,9 +187,25 @@ def generic_issubclass(scls: Any, cls: Any) -> Any:
182
187
if not cls_args :
183
188
return True
184
189
185
- return len (scls_args ) == len (cls_args ) and all (
186
- map (generic_issubclass , scls_args , cls_args )
187
- )
190
+ if len (scls_args ) != len (cls_args ):
191
+ return False
192
+
193
+ return _map_generic_issubclass (scls_args , cls_args , failfast = True )
194
+
195
+
196
+ def _map_generic_issubclass (
197
+ scls : Iterable [Any ], cls : Iterable [Any ], * , failfast : bool = False
198
+ ) -> bool | list [Any ]:
199
+ results = []
200
+ for scls_arg , cls_arg in zip (scls , cls ):
201
+ if not (result := generic_issubclass (scls_arg , cls_arg )) and failfast :
202
+ return False
203
+ elif isinstance (result , list ):
204
+ results .extend (result )
205
+ elif not isinstance (result , bool ):
206
+ results .append (result )
207
+
208
+ return results or False
188
209
189
210
190
211
def return_progressbar (func : Callable [_P , Iterable [_T ]]) -> Callable [_P , Iterable [_T ]]:
@@ -217,7 +238,11 @@ def get_parent_plugins(plugin: Plugin | None) -> Generator[Plugin, Any, None]:
217
238
def is_editable (plugin : Plugin ) -> bool :
218
239
* _ , plugin = get_parent_plugins (plugin )
219
240
220
- path = files (plugin .module )
241
+ try :
242
+ path = files (plugin .module )
243
+ except TypeError :
244
+ return False
245
+
221
246
if not isinstance (path , Path ) or "site-packages" in path .parts :
222
247
return False
223
248
0 commit comments