Skip to content

Commit c92972c

Browse files
committed
✨ feat(sqla): dependency cache and validate
1 parent 636c969 commit c92972c

File tree

2 files changed

+64
-50
lines changed

2 files changed

+64
-50
lines changed

nonebot_plugin_orm/param.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010
from pydantic.fields import FieldInfo
1111
from nonebot.dependencies import Param
12-
from nonebot.params import DependParam
12+
from nonebot.params import Depends, DependParam
1313
from sqlalchemy import Row, Result, ScalarResult, select
1414
from sqlalchemy.sql.selectable import ExecutableReturnsRows
1515
from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult
1616

1717
from .model import Model
18-
from .utils import Option, compile_dependency, generic_issubclass
18+
from .utils import Option, Dependency, generic_issubclass
1919

2020
if sys.version_info >= (3, 10):
2121
from typing import Annotated, get_args, get_origin
@@ -119,12 +119,6 @@
119119
@dataclass
120120
class SQLDependsInner:
121121
dependency: ExecutableReturnsRows
122-
123-
if sys.version_info >= (3, 10):
124-
from dataclasses import KW_ONLY
125-
126-
_: KW_ONLY
127-
128122
use_cache: bool = True
129123
validate: bool | FieldInfo = False
130124

@@ -135,7 +129,7 @@ def SQLDepends(
135129
use_cache: bool = True,
136130
validate: bool | FieldInfo = False,
137131
) -> Any:
138-
return SQLDependsInner(dependency, use_cache=use_cache, validate=validate)
132+
return SQLDependsInner(dependency, use_cache, validate)
139133

140134

141135
class ORMParam(DependParam):
@@ -164,26 +158,35 @@ def _check_param(
164158
models = (models,)
165159

166160
if depends_inner is not None:
167-
dependency = compile_dependency(depends_inner.dependency, option)
161+
statement = depends_inner.dependency
168162
elif all(map(isclass, models)) and all(
169163
map(issubclass, cast(Tuple[type, ...], models), repeat(Model))
170164
):
171165
models = cast(Tuple[Type[Model], ...], models)
172-
dependency = compile_dependency(
173-
select(*models).where(
174-
*(
175-
getattr(model, name) == param.default
176-
for model in models
177-
for name, param in model.__signature__.parameters.items()
178-
)
179-
),
180-
option,
166+
# NOTE: statement is generated (see below)
167+
statement = select(*models).where(
168+
*(
169+
getattr(model, name) == param.default
170+
for model in models
171+
for name, param in model.__signature__.parameters.items()
172+
)
181173
)
182174
else:
183175
return
184176

185-
return super()._check_param(param.replace(default=dependency), allow_types)
177+
return super()._check_param(
178+
param.replace(
179+
default=Depends(
180+
Dependency(statement, option),
181+
use_cache=(
182+
depends_inner.use_cache if depends_inner else False
183+
), # NOTE: default use_cache=False as it is impossible to reuse a generated statement (see above)
184+
validate=depends_inner.validate if depends_inner else False,
185+
)
186+
),
187+
allow_types,
188+
)
186189

187190
@classmethod
188-
def _check_parameterless(cls, *_) -> Param | None:
191+
def _check_parameterless(cls, *_) -> None:
189192
return

nonebot_plugin_orm/utils.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from operator import methodcaller
1212
from typing_extensions import Annotated
1313
from dataclasses import field, dataclass
14-
from typing import Any, TypeVar, Coroutine
1514
from inspect import Parameter, Signature, isclass
1615
from collections.abc import Callable, Iterable, Generator
16+
from typing import TYPE_CHECKING, Any, TypeVar, Coroutine
1717
from importlib.metadata import Distribution, PackageNotFoundError, distribution
1818

1919
import click
@@ -36,6 +36,10 @@
3636
from typing_extensions import ParamSpec, get_args, get_origin
3737

3838

39+
if TYPE_CHECKING:
40+
from . import async_scoped_session
41+
42+
3943
_T = TypeVar("_T")
4044
_P = ParamSpec("_P")
4145

@@ -73,60 +77,67 @@ def write(self, buffer: str):
7377
while frame and frame.f_code.co_name != "print_stdout":
7478
frame = frame.f_back
7579
depth += 1
76-
depth += 1
7780

7881
for line in buffer.rstrip().splitlines():
79-
logger.opt(depth=depth).log(self._level, line.rstrip())
82+
logger.opt(depth=depth + 1).log(self._level, line.rstrip())
8083

8184
def flush(self):
8285
pass
8386

8487

85-
@dataclass
88+
@dataclass(unsafe_hash=True)
8689
class Option:
8790
stream: bool = True
8891
scalars: bool = False
8992
result: methodcaller | None = None
90-
calls: list[methodcaller] = field(default_factory=list)
93+
calls: tuple[methodcaller] = field(default_factory=tuple)
9194

9295

93-
def compile_dependency(statement: ExecutableReturnsRows, option: Option) -> Any:
94-
from . import async_scoped_session
96+
@dataclass
97+
class Dependency:
98+
__signature__: Signature = field(init=False)
99+
100+
statement: ExecutableReturnsRows
101+
option: Option
102+
103+
def __post_init__(self) -> None:
104+
from . import async_scoped_session
105+
106+
self.__signature__ = Signature(
107+
[
108+
Parameter(
109+
"_session", Parameter.KEYWORD_ONLY, annotation=async_scoped_session
110+
),
111+
*(
112+
Parameter(name, Parameter.KEYWORD_ONLY, default=depends)
113+
for name, depends in self.statement.compile().params.items()
114+
if isinstance(depends, DependsInner)
115+
),
116+
]
117+
)
95118

96-
async def __dependency(*, __session: async_scoped_session, **params: Any):
97-
if option.stream:
98-
result = await __session.stream(statement, params)
119+
async def __call__(self, *, _session: async_scoped_session, **params: Any) -> Any:
120+
if self.option.stream:
121+
result = await _session.stream(self.statement, params)
99122
else:
100-
result = await __session.execute(statement, params)
123+
result = await _session.execute(self.statement, params)
101124

102-
for call in option.calls:
125+
for call in self.option.calls:
103126
result = call(result)
104127

105-
if option.scalars:
128+
if self.option.scalars:
106129
result = result.scalars()
107130

108-
if call := option.result:
131+
if call := self.option.result:
109132
result = call(result)
110133

111-
if option.stream:
134+
if self.option.stream:
112135
result = await result
113136

114137
return result
115138

116-
__dependency.__signature__ = Signature(
117-
[
118-
Parameter(
119-
"__session", Parameter.KEYWORD_ONLY, annotation=async_scoped_session
120-
),
121-
*(
122-
Parameter(name, Parameter.KEYWORD_ONLY, default=depends)
123-
for name, depends in statement.compile().params.items()
124-
if isinstance(depends, DependsInner)
125-
),
126-
]
127-
)
128-
129-
return Depends(__dependency)
139+
def __hash__(self) -> int:
140+
return hash((self.statement, self.option))
130141

131142

132143
def generic_issubclass(scls: Any, cls: Any) -> Any:

0 commit comments

Comments
 (0)