Skip to content

Commit 81d3321

Browse files
committed
🐛 fix(sqla): dependency injection type check
1 parent ac212b0 commit 81d3321

File tree

2 files changed

+72
-30
lines changed

2 files changed

+72
-30
lines changed

nonebot_plugin_orm/param.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from pydantic.fields import FieldInfo
1111
from nonebot.dependencies import Param
12+
from nonebot.typing import origin_is_union
1213
from nonebot.params import Depends, DependParam
1314
from sqlalchemy import Row, Result, ScalarResult, select
1415
from sqlalchemy.sql.selectable import ExecutableReturnsRows
@@ -37,32 +38,32 @@
3738
AsyncIterator[Sequence[Row[Tuple[Any, ...]]]]: Option(
3839
True,
3940
False,
40-
methodcaller("partitions"),
41+
(methodcaller("partitions"),),
4142
),
4243
AsyncIterator[Sequence[Tuple[Any, ...]]]: Option(
4344
True,
4445
False,
45-
methodcaller("partitions"),
46+
(methodcaller("partitions"),),
4647
),
4748
AsyncIterator[Sequence[Any]]: Option(
4849
True,
4950
True,
50-
methodcaller("partitions"),
51+
(methodcaller("partitions"),),
5152
),
5253
Iterator[Sequence[Row[Tuple[Any, ...]]]]: Option(
5354
False,
5455
False,
55-
methodcaller("partitions"),
56+
(methodcaller("partitions"),),
5657
),
5758
Iterator[Sequence[Tuple[Any, ...]]]: Option(
5859
False,
5960
False,
60-
methodcaller("partitions"),
61+
(methodcaller("partitions"),),
6162
),
6263
Iterator[Sequence[Any]]: Option(
6364
False,
6465
True,
65-
methodcaller("partitions"),
66+
(methodcaller("partitions"),),
6667
),
6768
AsyncResult[Tuple[Any, ...]]: Option(
6869
True,
@@ -91,26 +92,31 @@
9192
Sequence[Row[Tuple[Any, ...]]]: Option(
9293
True,
9394
False,
95+
(),
9496
methodcaller("all"),
9597
),
9698
Sequence[Tuple[Any, ...]]: Option(
9799
True,
98100
False,
101+
(),
99102
methodcaller("all"),
100103
),
101104
Sequence[Any]: Option(
102105
True,
103106
True,
107+
(),
104108
methodcaller("all"),
105109
),
106110
Tuple[Any, ...]: Option(
107111
True,
108112
False,
113+
(),
109114
methodcaller("one_or_none"),
110115
),
111116
Any: Option(
112117
True,
113118
True,
119+
(),
114120
methodcaller("one_or_none"),
115121
),
116122
}
@@ -149,20 +155,31 @@ def _check_param(
149155
depends_inner = param.default
150156

151157
for pattern, option in PATTERNS.items():
152-
if models := generic_issubclass(pattern, type_annotation):
158+
if models := cast(
159+
"list[Any]", generic_issubclass(pattern, type_annotation)
160+
):
153161
break
154162
else:
155-
models, option = None, Option()
163+
models, option = [], Option()
164+
165+
for index, model in enumerate(models):
166+
if origin_is_union(get_origin(model)):
167+
models[index] = next(
168+
(
169+
arg
170+
for arg in get_args(model)
171+
if isclass(arg) and issubclass(arg, Model)
172+
),
173+
None,
174+
)
156175

157-
if not isinstance(models, tuple):
158-
models = (models,)
176+
if not (isclass(models[index]) and issubclass(models[index], Model)):
177+
models = []
178+
break
159179

160180
if depends_inner is not None:
161181
statement = depends_inner.dependency
162-
elif all(map(isclass, models)) and all(
163-
map(issubclass, cast(Tuple[type, ...], models), repeat(Model))
164-
):
165-
models = cast(Tuple[Type[Model], ...], models)
182+
elif models:
166183
# NOTE: statement is generated (see below)
167184
statement = select(*models).where(
168185
*(

nonebot_plugin_orm/utils.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from operator import methodcaller
1212
from typing_extensions import Annotated
1313
from dataclasses import field, dataclass
14-
from inspect import Parameter, Signature, isclass
14+
from inspect import Parameter, Signature
1515
from collections.abc import Callable, Iterable, Generator
1616
from typing import TYPE_CHECKING, Any, TypeVar, Coroutine
1717
from importlib.metadata import Distribution, PackageNotFoundError, distribution
@@ -89,8 +89,8 @@ def flush(self):
8989
class Option:
9090
stream: bool = True
9191
scalars: bool = False
92+
calls: tuple[methodcaller, ...] = field(default_factory=tuple)
9293
result: methodcaller | None = None
93-
calls: tuple[methodcaller] = field(default_factory=tuple)
9494

9595

9696
@dataclass
@@ -122,12 +122,12 @@ async def __call__(self, *, _session: async_scoped_session, **params: Any) -> An
122122
else:
123123
result = await _session.execute(self.statement, params)
124124

125-
for call in self.option.calls:
126-
result = call(result)
127-
128125
if self.option.scalars:
129126
result = result.scalars()
130127

128+
for call in self.option.calls:
129+
result = call(result)
130+
131131
if call := self.option.result:
132132
result = call(result)
133133

@@ -140,14 +140,17 @@ def __hash__(self) -> int:
140140
return hash((self.statement, self.option))
141141

142142

143-
def generic_issubclass(scls: Any, cls: Any) -> Any:
144-
if cls is Any:
145-
return True
143+
def generic_issubclass(scls: Any, cls: Any) -> bool | list[Any]:
144+
if isinstance(cls, tuple):
145+
return _map_generic_issubclass(repeat(scls), cls)
146146

147147
if scls is Any:
148-
return cls
148+
return [cls]
149+
150+
if cls is Any:
151+
return True
149152

150-
if isclass(scls) and (isclass(cls) or isinstance(cls, tuple)):
153+
with suppress(TypeError):
151154
return issubclass(scls, cls)
152155

153156
scls_origin, scls_args = get_origin(scls) or scls, get_args(scls)
@@ -158,15 +161,17 @@ def generic_issubclass(scls: Any, cls: Any) -> Any:
158161
return generic_issubclass(scls_args[0], cls_args)
159162

160163
if len(cls_args) == 2 and cls_args[1] is Ellipsis:
161-
return all(map(generic_issubclass, scls_args, repeat(cls_args[0])))
164+
return _map_generic_issubclass(
165+
scls_args, repeat(cls_args[0]), failfast=True
166+
)
162167

163168
if scls_origin is Annotated:
164169
return generic_issubclass(scls_args[0], cls)
165170
if cls_origin is Annotated:
166171
return generic_issubclass(scls, cls_args[0])
167172

168173
if origin_is_union(scls_origin):
169-
return all(map(generic_issubclass, scls_args, repeat(cls)))
174+
return _map_generic_issubclass(scls_args, repeat(cls), failfast=True)
170175
if origin_is_union(cls_origin):
171176
return generic_issubclass(scls, cls_args)
172177

@@ -182,9 +187,25 @@ def generic_issubclass(scls: Any, cls: Any) -> Any:
182187
if not cls_args:
183188
return True
184189

185-
return len(scls_args) == len(cls_args) and all(
186-
map(generic_issubclass, scls_args, cls_args)
187-
)
190+
if len(scls_args) != len(cls_args):
191+
return False
192+
193+
return _map_generic_issubclass(scls_args, cls_args, failfast=True)
194+
195+
196+
def _map_generic_issubclass(
197+
scls: Iterable[Any], cls: Iterable[Any], *, failfast: bool = False
198+
) -> bool | list[Any]:
199+
results = []
200+
for scls_arg, cls_arg in zip(scls, cls):
201+
if not (result := generic_issubclass(scls_arg, cls_arg)) and failfast:
202+
return False
203+
elif isinstance(result, list):
204+
results.extend(result)
205+
elif not isinstance(result, bool):
206+
results.append(result)
207+
208+
return results or False
188209

189210

190211
def return_progressbar(func: Callable[_P, Iterable[_T]]) -> Callable[_P, Iterable[_T]]:
@@ -217,7 +238,11 @@ def get_parent_plugins(plugin: Plugin | None) -> Generator[Plugin, Any, None]:
217238
def is_editable(plugin: Plugin) -> bool:
218239
*_, plugin = get_parent_plugins(plugin)
219240

220-
path = files(plugin.module)
241+
try:
242+
path = files(plugin.module)
243+
except TypeError:
244+
return False
245+
221246
if not isinstance(path, Path) or "site-packages" in path.parts:
222247
return False
223248

0 commit comments

Comments
 (0)