5
5
from .result_wrappers import fetch_models
6
6
from .utils import CursorProtocol
7
7
from typing_extensions import Self
8
- from typing import Tuple , List , Any , cast , Optional
8
+ from typing import Tuple , List , Any , cast , Optional , Dict , Union
9
9
10
10
11
- async def aio_prefetch (sq , * subqueries , prefetch_type : PREFETCH_TYPE = PREFETCH_TYPE .WHERE ):
11
+ async def aio_prefetch (sq : Any , * subqueries : Any , prefetch_type : PREFETCH_TYPE = PREFETCH_TYPE .WHERE ) -> Any :
12
12
"""Asynchronous version of `prefetch()`.
13
13
14
14
See also:
@@ -18,8 +18,8 @@ async def aio_prefetch(sq, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_
18
18
return sq
19
19
20
20
fixed_queries = peewee .prefetch_add_subquery (sq , subqueries , prefetch_type )
21
- deps = {}
22
- rel_map = {}
21
+ deps : Dict [ Any , Any ] = {}
22
+ rel_map : Dict [ Any , Any ] = {}
23
23
24
24
for pq in reversed (fixed_queries ):
25
25
query_model = pq .model
@@ -49,27 +49,27 @@ class AioQueryMixin:
49
49
async def aio_execute (self , database : AioDatabase ) -> Any :
50
50
return await database .aio_execute (self )
51
51
52
- async def fetch_results (self , cursor : CursorProtocol ) -> List [ Any ] :
52
+ async def fetch_results (self , cursor : CursorProtocol ) -> Any :
53
53
return await fetch_models (cursor , self )
54
54
55
55
56
56
class AioModelDelete (peewee .ModelDelete , AioQueryMixin ):
57
- async def fetch_results (self , cursor : CursorProtocol ):
57
+ async def fetch_results (self , cursor : CursorProtocol ) -> Union [ List [ Any ], int ] :
58
58
if self ._returning :
59
59
return await fetch_models (cursor , self )
60
60
return cursor .rowcount
61
61
62
62
63
63
class AioModelUpdate (peewee .ModelUpdate , AioQueryMixin ):
64
64
65
- async def fetch_results (self , cursor : CursorProtocol ):
65
+ async def fetch_results (self , cursor : CursorProtocol ) -> Union [ List [ Any ], int ] :
66
66
if self ._returning :
67
67
return await fetch_models (cursor , self )
68
68
return cursor .rowcount
69
69
70
70
71
71
class AioModelInsert (peewee .ModelInsert , AioQueryMixin ):
72
- async def fetch_results (self , cursor : CursorProtocol ):
72
+ async def fetch_results (self , cursor : CursorProtocol ) -> Union [ List [ Any ], Any , int ] :
73
73
if self ._returning is not None and len (self ._returning ) > 1 :
74
74
return await fetch_models (cursor , self )
75
75
@@ -96,26 +96,26 @@ async def aio_scalar(self, database: AioDatabase, as_tuple: bool = False) -> Any
96
96
See also:
97
97
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.scalar
98
98
"""
99
- async def fetch_results (cursor ) :
99
+ async def fetch_results (cursor : CursorProtocol ) -> Any :
100
100
return await cursor .fetchone ()
101
101
102
102
rows = await database .aio_execute (self , fetch_results = fetch_results )
103
103
104
104
return rows [0 ] if rows and not as_tuple else rows
105
105
106
- async def aio_get (self , database : Optional [AioDatabase ] = None ):
106
+ async def aio_get (self , database : Optional [AioDatabase ] = None ) -> Any :
107
107
"""
108
108
Async version of **peewee.SelectBase.get**
109
109
110
110
See also:
111
111
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.get
112
112
"""
113
- clone = self .paginate (1 , 1 )
113
+ clone = self .paginate (1 , 1 ) # type: ignore
114
114
try :
115
115
return (await clone .aio_execute (database ))[0 ]
116
116
except IndexError :
117
117
sql , params = clone .sql ()
118
- raise self .model .DoesNotExist ('%s instance matching query does '
118
+ raise self .model .DoesNotExist ('%s instance matching query does ' # type: ignore
119
119
'not exist:\n SQL: %s\n Params: %s' %
120
120
(clone .model , sql , params ))
121
121
@@ -127,7 +127,7 @@ async def aio_count(self, database: AioDatabase, clear_limit: bool = False) -> i
127
127
See also:
128
128
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.count
129
129
"""
130
- clone = self .order_by ().alias ('_wrapped' )
130
+ clone = self .order_by ().alias ('_wrapped' ) # type: ignore
131
131
if clear_limit :
132
132
clone ._limit = clone ._offset = None
133
133
try :
@@ -150,28 +150,28 @@ async def aio_exists(self, database: AioDatabase) -> bool:
150
150
See also:
151
151
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.exists
152
152
"""
153
- clone = self .columns (peewee .SQL ('1' ))
153
+ clone = self .columns (peewee .SQL ('1' )) # type: ignore
154
154
clone ._limit = 1
155
155
clone ._offset = None
156
156
return bool (await clone .aio_scalar ())
157
157
158
- def union_all (self , rhs ) :
159
- return AioModelCompoundSelectQuery (self .model , self , 'UNION ALL' , rhs )
158
+ def union_all (self , rhs : Any ) -> "AioModelCompoundSelectQuery" :
159
+ return AioModelCompoundSelectQuery (self .model , self , 'UNION ALL' , rhs ) # type: ignore
160
160
__add__ = union_all
161
161
162
- def union (self , rhs ) :
163
- return AioModelCompoundSelectQuery (self .model , self , 'UNION' , rhs )
162
+ def union (self , rhs : Any ) -> "AioModelCompoundSelectQuery" :
163
+ return AioModelCompoundSelectQuery (self .model , self , 'UNION' , rhs ) # type: ignore
164
164
__or__ = union
165
165
166
- def intersect (self , rhs ) :
167
- return AioModelCompoundSelectQuery (self .model , self , 'INTERSECT' , rhs )
166
+ def intersect (self , rhs : Any ) -> "AioModelCompoundSelectQuery" :
167
+ return AioModelCompoundSelectQuery (self .model , self , 'INTERSECT' , rhs ) # type: ignore
168
168
__and__ = intersect
169
169
170
- def except_ (self , rhs ) :
171
- return AioModelCompoundSelectQuery (self .model , self , 'EXCEPT' , rhs )
170
+ def except_ (self , rhs : Any ) -> "AioModelCompoundSelectQuery" :
171
+ return AioModelCompoundSelectQuery (self .model , self , 'EXCEPT' , rhs ) # type: ignore
172
172
__sub__ = except_
173
173
174
- def aio_prefetch (self , * subqueries , prefetch_type : PREFETCH_TYPE = PREFETCH_TYPE .WHERE ):
174
+ def aio_prefetch (self , * subqueries : Any , prefetch_type : PREFETCH_TYPE = PREFETCH_TYPE .WHERE ) -> Any :
175
175
"""
176
176
Async version of **peewee.ModelSelect.prefetch**
177
177
@@ -214,32 +214,32 @@ class User(peewee_async.AioModel):
214
214
"""
215
215
216
216
@classmethod
217
- def select (cls , * fields ) -> AioModelSelect :
217
+ def select (cls , * fields : Any ) -> AioModelSelect :
218
218
is_default = not fields
219
219
if not fields :
220
220
fields = cls ._meta .sorted_fields
221
221
return AioModelSelect (cls , fields , is_default = is_default )
222
222
223
223
@classmethod
224
- def update (cls , __data = None , ** update ) -> AioModelUpdate :
224
+ def update (cls , __data : Any = None , ** update : Any ) -> AioModelUpdate :
225
225
return AioModelUpdate (cls , cls ._normalize_data (__data , update ))
226
226
227
227
@classmethod
228
- def insert (cls , __data = None , ** insert ) -> AioModelInsert :
228
+ def insert (cls , __data : Any = None , ** insert : Any ) -> AioModelInsert :
229
229
return AioModelInsert (cls , cls ._normalize_data (__data , insert ))
230
230
231
231
@classmethod
232
- def insert_many (cls , rows , fields = None ) -> AioModelInsert :
232
+ def insert_many (cls , rows : Any , fields : Any = None ) -> AioModelInsert :
233
233
return AioModelInsert (cls , insert = rows , columns = fields )
234
234
235
235
@classmethod
236
- def insert_from (cls , query , fields ) -> AioModelInsert :
236
+ def insert_from (cls , query : Any , fields : Any ) -> AioModelInsert :
237
237
columns = [getattr (cls , field ) if isinstance (field , str )
238
238
else field for field in fields ]
239
239
return AioModelInsert (cls , insert = query , columns = columns )
240
240
241
241
@classmethod
242
- def raw (cls , sql , * params ) -> AioModelRaw :
242
+ def raw (cls , sql : Optional [ str ] , * params : Optional [ List [ Any ]] ) -> AioModelRaw :
243
243
return AioModelRaw (cls , sql , params )
244
244
245
245
@classmethod
@@ -263,7 +263,7 @@ async def aio_delete_instance(self, recursive: bool = False, delete_nullable: bo
263
263
await model .delete ().where (query ).aio_execute ()
264
264
return cast (int , await type (self ).delete ().where (self ._pk_expr ()).aio_execute ())
265
265
266
- async def aio_save (self , force_insert : bool = False , only = None ) -> int :
266
+ async def aio_save (self , force_insert : bool = False , only : Any = None ) -> int :
267
267
"""
268
268
Async version of **peewee.Model.save**
269
269
@@ -273,7 +273,7 @@ async def aio_save(self, force_insert: bool = False, only=None) -> int:
273
273
field_dict = self .__data__ .copy ()
274
274
if self ._meta .primary_key is not False :
275
275
pk_field = self ._meta .primary_key
276
- pk_value = self ._pk
276
+ pk_value = self ._pk # type: ignore
277
277
else :
278
278
pk_field = pk_value = None
279
279
if only is not None :
@@ -313,7 +313,7 @@ async def aio_save(self, force_insert: bool = False, only=None) -> int:
313
313
return rows
314
314
315
315
@classmethod
316
- async def aio_get (cls , * query , ** filters ) -> Self :
316
+ async def aio_get (cls , * query : Any , ** filters : Any ) -> Self :
317
317
"""Async version of **peewee.Model.get**
318
318
319
319
See also:
@@ -327,7 +327,7 @@ async def aio_get(cls, *query, **filters) -> Self:
327
327
sq = sq .where (* query )
328
328
if filters :
329
329
sq = sq .filter (** filters )
330
- return await sq .aio_get ()
330
+ return cast ( Self , await sq .aio_get () )
331
331
332
332
@classmethod
333
333
async def aio_get_or_none (cls , * query : Any , ** filters : Any ) -> Optional [Self ]:
0 commit comments