|
11 | 11 | from operator import methodcaller
|
12 | 12 | from typing_extensions import Annotated
|
13 | 13 | from dataclasses import field, dataclass
|
14 |
| -from typing import Any, TypeVar, Coroutine |
15 | 14 | from inspect import Parameter, Signature, isclass
|
16 | 15 | from collections.abc import Callable, Iterable, Generator
|
| 16 | +from typing import TYPE_CHECKING, Any, TypeVar, Coroutine |
17 | 17 | from importlib.metadata import Distribution, PackageNotFoundError, distribution
|
18 | 18 |
|
19 | 19 | import click
|
|
36 | 36 | from typing_extensions import ParamSpec, get_args, get_origin
|
37 | 37 |
|
38 | 38 |
|
| 39 | +if TYPE_CHECKING: |
| 40 | + from . import async_scoped_session |
| 41 | + |
| 42 | + |
39 | 43 | _T = TypeVar("_T")
|
40 | 44 | _P = ParamSpec("_P")
|
41 | 45 |
|
@@ -73,60 +77,67 @@ def write(self, buffer: str):
|
73 | 77 | while frame and frame.f_code.co_name != "print_stdout":
|
74 | 78 | frame = frame.f_back
|
75 | 79 | depth += 1
|
76 |
| - depth += 1 |
77 | 80 |
|
78 | 81 | for line in buffer.rstrip().splitlines():
|
79 |
| - logger.opt(depth=depth).log(self._level, line.rstrip()) |
| 82 | + logger.opt(depth=depth + 1).log(self._level, line.rstrip()) |
80 | 83 |
|
81 | 84 | def flush(self):
|
82 | 85 | pass
|
83 | 86 |
|
84 | 87 |
|
85 |
| -@dataclass |
| 88 | +@dataclass(unsafe_hash=True) |
86 | 89 | class Option:
|
87 | 90 | stream: bool = True
|
88 | 91 | scalars: bool = False
|
89 | 92 | result: methodcaller | None = None
|
90 |
| - calls: list[methodcaller] = field(default_factory=list) |
| 93 | + calls: tuple[methodcaller] = field(default_factory=tuple) |
91 | 94 |
|
92 | 95 |
|
93 |
| -def compile_dependency(statement: ExecutableReturnsRows, option: Option) -> Any: |
94 |
| - from . import async_scoped_session |
| 96 | +@dataclass |
| 97 | +class Dependency: |
| 98 | + __signature__: Signature = field(init=False) |
| 99 | + |
| 100 | + statement: ExecutableReturnsRows |
| 101 | + option: Option |
| 102 | + |
| 103 | + def __post_init__(self) -> None: |
| 104 | + from . import async_scoped_session |
| 105 | + |
| 106 | + self.__signature__ = Signature( |
| 107 | + [ |
| 108 | + Parameter( |
| 109 | + "_session", Parameter.KEYWORD_ONLY, annotation=async_scoped_session |
| 110 | + ), |
| 111 | + *( |
| 112 | + Parameter(name, Parameter.KEYWORD_ONLY, default=depends) |
| 113 | + for name, depends in self.statement.compile().params.items() |
| 114 | + if isinstance(depends, DependsInner) |
| 115 | + ), |
| 116 | + ] |
| 117 | + ) |
95 | 118 |
|
96 |
| - async def __dependency(*, __session: async_scoped_session, **params: Any): |
97 |
| - if option.stream: |
98 |
| - result = await __session.stream(statement, params) |
| 119 | + async def __call__(self, *, _session: async_scoped_session, **params: Any) -> Any: |
| 120 | + if self.option.stream: |
| 121 | + result = await _session.stream(self.statement, params) |
99 | 122 | else:
|
100 |
| - result = await __session.execute(statement, params) |
| 123 | + result = await _session.execute(self.statement, params) |
101 | 124 |
|
102 |
| - for call in option.calls: |
| 125 | + for call in self.option.calls: |
103 | 126 | result = call(result)
|
104 | 127 |
|
105 |
| - if option.scalars: |
| 128 | + if self.option.scalars: |
106 | 129 | result = result.scalars()
|
107 | 130 |
|
108 |
| - if call := option.result: |
| 131 | + if call := self.option.result: |
109 | 132 | result = call(result)
|
110 | 133 |
|
111 |
| - if option.stream: |
| 134 | + if self.option.stream: |
112 | 135 | result = await result
|
113 | 136 |
|
114 | 137 | return result
|
115 | 138 |
|
116 |
| - __dependency.__signature__ = Signature( |
117 |
| - [ |
118 |
| - Parameter( |
119 |
| - "__session", Parameter.KEYWORD_ONLY, annotation=async_scoped_session |
120 |
| - ), |
121 |
| - *( |
122 |
| - Parameter(name, Parameter.KEYWORD_ONLY, default=depends) |
123 |
| - for name, depends in statement.compile().params.items() |
124 |
| - if isinstance(depends, DependsInner) |
125 |
| - ), |
126 |
| - ] |
127 |
| - ) |
128 |
| - |
129 |
| - return Depends(__dependency) |
| 139 | + def __hash__(self) -> int: |
| 140 | + return hash((self.statement, self.option)) |
130 | 141 |
|
131 | 142 |
|
132 | 143 | def generic_issubclass(scls: Any, cls: Any) -> Any:
|
|
0 commit comments