Skip to content

Commit c363a14

Browse files
committed
feat: finished typing
1 parent c43cb59 commit c363a14

File tree

3 files changed

+47
-40
lines changed

3 files changed

+47
-40
lines changed

peewee_async/aio_model.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from .result_wrappers import fetch_models
66
from .utils import CursorProtocol
77
from typing_extensions import Self
8-
from typing import Tuple, List, Any, cast, Optional
8+
from typing import Tuple, List, Any, cast, Optional, Dict, Union
99

1010

11-
async def aio_prefetch(sq, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE):
11+
async def aio_prefetch(sq: Any, *subqueries: Any, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> Any:
1212
"""Asynchronous version of `prefetch()`.
1313
1414
See also:
@@ -18,8 +18,8 @@ async def aio_prefetch(sq, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_
1818
return sq
1919

2020
fixed_queries = peewee.prefetch_add_subquery(sq, subqueries, prefetch_type)
21-
deps = {}
22-
rel_map = {}
21+
deps: Dict[Any, Any] = {}
22+
rel_map: Dict[Any, Any] = {}
2323

2424
for pq in reversed(fixed_queries):
2525
query_model = pq.model
@@ -49,27 +49,27 @@ class AioQueryMixin:
4949
async def aio_execute(self, database: AioDatabase) -> Any:
5050
return await database.aio_execute(self)
5151

52-
async def fetch_results(self, cursor: CursorProtocol) -> List[Any]:
52+
async def fetch_results(self, cursor: CursorProtocol) -> Any:
5353
return await fetch_models(cursor, self)
5454

5555

5656
class AioModelDelete(peewee.ModelDelete, AioQueryMixin):
57-
async def fetch_results(self, cursor: CursorProtocol):
57+
async def fetch_results(self, cursor: CursorProtocol) -> Union[List[Any], int]:
5858
if self._returning:
5959
return await fetch_models(cursor, self)
6060
return cursor.rowcount
6161

6262

6363
class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin):
6464

65-
async def fetch_results(self, cursor: CursorProtocol):
65+
async def fetch_results(self, cursor: CursorProtocol) -> Union[List[Any], int]:
6666
if self._returning:
6767
return await fetch_models(cursor, self)
6868
return cursor.rowcount
6969

7070

7171
class AioModelInsert(peewee.ModelInsert, AioQueryMixin):
72-
async def fetch_results(self, cursor: CursorProtocol):
72+
async def fetch_results(self, cursor: CursorProtocol) -> Union[List[Any], Any, int]:
7373
if self._returning is not None and len(self._returning) > 1:
7474
return await fetch_models(cursor, self)
7575

@@ -96,26 +96,26 @@ async def aio_scalar(self, database: AioDatabase, as_tuple: bool = False) -> Any
9696
See also:
9797
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.scalar
9898
"""
99-
async def fetch_results(cursor):
99+
async def fetch_results(cursor: CursorProtocol) -> Any:
100100
return await cursor.fetchone()
101101

102102
rows = await database.aio_execute(self, fetch_results=fetch_results)
103103

104104
return rows[0] if rows and not as_tuple else rows
105105

106-
async def aio_get(self, database: Optional[AioDatabase] = None):
106+
async def aio_get(self, database: Optional[AioDatabase] = None) -> Any:
107107
"""
108108
Async version of **peewee.SelectBase.get**
109109
110110
See also:
111111
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.get
112112
"""
113-
clone = self.paginate(1, 1)
113+
clone = self.paginate(1, 1) # type: ignore
114114
try:
115115
return (await clone.aio_execute(database))[0]
116116
except IndexError:
117117
sql, params = clone.sql()
118-
raise self.model.DoesNotExist('%s instance matching query does '
118+
raise self.model.DoesNotExist('%s instance matching query does ' # type: ignore
119119
'not exist:\nSQL: %s\nParams: %s' %
120120
(clone.model, sql, params))
121121

@@ -127,7 +127,7 @@ async def aio_count(self, database: AioDatabase, clear_limit: bool = False) -> i
127127
See also:
128128
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.count
129129
"""
130-
clone = self.order_by().alias('_wrapped')
130+
clone = self.order_by().alias('_wrapped') # type: ignore
131131
if clear_limit:
132132
clone._limit = clone._offset = None
133133
try:
@@ -150,28 +150,28 @@ async def aio_exists(self, database: AioDatabase) -> bool:
150150
See also:
151151
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.exists
152152
"""
153-
clone = self.columns(peewee.SQL('1'))
153+
clone = self.columns(peewee.SQL('1')) # type: ignore
154154
clone._limit = 1
155155
clone._offset = None
156156
return bool(await clone.aio_scalar())
157157

158-
def union_all(self, rhs):
159-
return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs)
158+
def union_all(self, rhs: Any) -> "AioModelCompoundSelectQuery":
159+
return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) # type: ignore
160160
__add__ = union_all
161161

162-
def union(self, rhs):
163-
return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs)
162+
def union(self, rhs: Any) -> "AioModelCompoundSelectQuery":
163+
return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs) # type: ignore
164164
__or__ = union
165165

166-
def intersect(self, rhs):
167-
return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs)
166+
def intersect(self, rhs: Any) -> "AioModelCompoundSelectQuery":
167+
return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) # type: ignore
168168
__and__ = intersect
169169

170-
def except_(self, rhs):
171-
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs)
170+
def except_(self, rhs: Any) -> "AioModelCompoundSelectQuery":
171+
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) # type: ignore
172172
__sub__ = except_
173173

174-
def aio_prefetch(self, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE):
174+
def aio_prefetch(self, *subqueries: Any, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> Any:
175175
"""
176176
Async version of **peewee.ModelSelect.prefetch**
177177
@@ -214,32 +214,32 @@ class User(peewee_async.AioModel):
214214
"""
215215

216216
@classmethod
217-
def select(cls, *fields) -> AioModelSelect:
217+
def select(cls, *fields: Any) -> AioModelSelect:
218218
is_default = not fields
219219
if not fields:
220220
fields = cls._meta.sorted_fields
221221
return AioModelSelect(cls, fields, is_default=is_default)
222222

223223
@classmethod
224-
def update(cls, __data=None, **update) -> AioModelUpdate:
224+
def update(cls, __data: Any = None, **update: Any) -> AioModelUpdate:
225225
return AioModelUpdate(cls, cls._normalize_data(__data, update))
226226

227227
@classmethod
228-
def insert(cls, __data=None, **insert) -> AioModelInsert:
228+
def insert(cls, __data: Any = None, **insert: Any) -> AioModelInsert:
229229
return AioModelInsert(cls, cls._normalize_data(__data, insert))
230230

231231
@classmethod
232-
def insert_many(cls, rows, fields=None) -> AioModelInsert:
232+
def insert_many(cls, rows: Any, fields: Any = None) -> AioModelInsert:
233233
return AioModelInsert(cls, insert=rows, columns=fields)
234234

235235
@classmethod
236-
def insert_from(cls, query, fields) -> AioModelInsert:
236+
def insert_from(cls, query: Any, fields: Any) -> AioModelInsert:
237237
columns = [getattr(cls, field) if isinstance(field, str)
238238
else field for field in fields]
239239
return AioModelInsert(cls, insert=query, columns=columns)
240240

241241
@classmethod
242-
def raw(cls, sql, *params) -> AioModelRaw:
242+
def raw(cls, sql: Optional[str], *params: Optional[List[Any]]) -> AioModelRaw:
243243
return AioModelRaw(cls, sql, params)
244244

245245
@classmethod
@@ -263,7 +263,7 @@ async def aio_delete_instance(self, recursive: bool = False, delete_nullable: bo
263263
await model.delete().where(query).aio_execute()
264264
return cast(int, await type(self).delete().where(self._pk_expr()).aio_execute())
265265

266-
async def aio_save(self, force_insert: bool = False, only=None) -> int:
266+
async def aio_save(self, force_insert: bool = False, only: Any =None) -> int:
267267
"""
268268
Async version of **peewee.Model.save**
269269
@@ -273,7 +273,7 @@ async def aio_save(self, force_insert: bool = False, only=None) -> int:
273273
field_dict = self.__data__.copy()
274274
if self._meta.primary_key is not False:
275275
pk_field = self._meta.primary_key
276-
pk_value = self._pk
276+
pk_value = self._pk # type: ignore
277277
else:
278278
pk_field = pk_value = None
279279
if only is not None:
@@ -313,7 +313,7 @@ async def aio_save(self, force_insert: bool = False, only=None) -> int:
313313
return rows
314314

315315
@classmethod
316-
async def aio_get(cls, *query, **filters) -> Self:
316+
async def aio_get(cls, *query: Any, **filters: Any) -> Self:
317317
"""Async version of **peewee.Model.get**
318318
319319
See also:
@@ -327,7 +327,7 @@ async def aio_get(cls, *query, **filters) -> Self:
327327
sq = sq.where(*query)
328328
if filters:
329329
sq = sq.filter(**filters)
330-
return await sq.aio_get()
330+
return cast(Self, await sq.aio_get())
331331

332332
@classmethod
333333
async def aio_get_or_none(cls, *query: Any, **filters: Any) -> Optional[Self]:

peewee_async/databases.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .connection import connection_context, ConnectionContextManager
99
from .pool import PoolBackend, PostgresqlPoolBackend, MysqlPoolBackend
1010
from .transactions import Transaction
11-
from .utils import aiopg, aiomysql, __log__
11+
from .utils import aiopg, aiomysql, __log__, FetchResults
1212

1313

1414
class AioDatabase(peewee.Database):
@@ -109,7 +109,7 @@ def allow_sync(self) -> Iterator[None]:
109109
self._allow_sync = old_allow_sync
110110
self.close()
111111

112-
def execute_sql(self, *args: Any, **kwargs: Any):
112+
def execute_sql(self, *args: Any, **kwargs: Any) -> Any:
113113
"""Sync execute SQL query, `allow_sync` must be set to True.
114114
"""
115115
assert self._allow_sync, (
@@ -129,7 +129,12 @@ def aio_connection(self) -> ConnectionContextManager:
129129

130130
return ConnectionContextManager(self.pool_backend)
131131

132-
async def aio_execute_sql(self, sql: str, params: Optional[List[Any]] = None, fetch_results=None):
132+
async def aio_execute_sql(
133+
self,
134+
sql: str,
135+
params: Optional[List[Any]] = None,
136+
fetch_results: Optional[FetchResults] = None
137+
) -> Any:
133138
__log__.debug(sql, params)
134139
with peewee.__exception_wrapper__:
135140
async with self.aio_connection() as connection:
@@ -138,7 +143,7 @@ async def aio_execute_sql(self, sql: str, params: Optional[List[Any]] = None, fe
138143
if fetch_results is not None:
139144
return await fetch_results(cursor)
140145

141-
async def aio_execute(self, query, fetch_results=None) -> Any:
146+
async def aio_execute(self, query: Any, fetch_results: Optional[FetchResults] = None) -> Any:
142147
"""Execute *SELECT*, *INSERT*, *UPDATE* or *DELETE* query asyncronously.
143148
144149
:param query: peewee query instance created with ``Model.select()``,

peewee_async/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import logging
2-
from typing import Any, Protocol, Optional, Sequence, Set, AsyncContextManager, List
2+
from typing import Any, Protocol, Optional, Sequence, Set, AsyncContextManager, List, Callable, Awaitable
33

44
try:
55
import aiopg
66
import psycopg2
77
except ImportError:
8-
aiopg = None
8+
aiopg = None # type: ignore
99
psycopg2 = None
1010

1111
try:
1212
import aiomysql
1313
import pymysql
1414
except ImportError:
1515
aiomysql = None
16-
pymysql = None
16+
pymysql = None # type: ignore
1717

1818
__log__ = logging.getLogger('peewee.async')
1919
__log__.addHandler(logging.NullHandler())
@@ -70,3 +70,5 @@ def terminate(self) -> None:
7070
async def wait_closed(self) -> None:
7171
...
7272

73+
74+
FetchResults = Callable[[CursorProtocol], Awaitable[Any]]

0 commit comments

Comments
 (0)