Skip to content

Commit 1cc6481

Browse files
authored
feat: add aio_peek aio_frist (#312)
1 parent a0d2204 commit 1cc6481

File tree

6 files changed

+186
-50
lines changed

6 files changed

+186
-50
lines changed

docs/peewee_async/api.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,12 @@ AioModelSelect
5757

5858
.. autoclass:: peewee_async.aio_model.AioModelSelect
5959

60+
.. automethod:: peewee_async.aio_model.AioModelSelect.aio_peek
61+
6062
.. automethod:: peewee_async.aio_model.AioModelSelect.aio_scalar
6163

64+
.. automethod:: peewee_async.aio_model.AioModelSelect.aio_first
65+
6266
.. automethod:: peewee_async.aio_model.AioModelSelect.aio_get
6367

6468
.. automethod:: peewee_async.aio_model.AioModelSelect.aio_count

peewee_async/aio_model.py

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -84,50 +84,72 @@ class AioModelRaw(peewee.ModelRaw, AioQueryMixin):
8484
pass
8585

8686

87-
class AioSelectMixin(AioQueryMixin):
87+
class AioSelectMixin(AioQueryMixin, peewee.SelectBase):
88+
8889

8990
@peewee.database_required
90-
async def aio_scalar(self, database: AioDatabase, as_tuple: bool = False) -> Any:
91+
async def aio_peek(self, database: AioDatabase, n: int = 1) -> Any:
9192
"""
92-
Get single value from ``select()`` query, i.e. for aggregation.
93-
94-
:return: result is the same as after sync ``query.scalar()`` call
95-
96-
See also:
97-
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.scalar
93+
Asynchronous version of
94+
`peewee.SelectBase.peek <https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.peek>`_
9895
"""
96+
9997
async def fetch_results(cursor: CursorProtocol) -> Any:
100-
return await cursor.fetchone()
98+
return await fetch_models(cursor, self, n)
10199

102100
rows = await database.aio_execute(self, fetch_results=fetch_results)
101+
if rows:
102+
return rows[0] if n == 1 else rows
103103

104-
return rows[0] if rows and not as_tuple else rows
104+
@peewee.database_required
105+
async def aio_scalar(
106+
self,
107+
database: AioDatabase,
108+
as_tuple: bool = False,
109+
as_dict: bool = False
110+
) -> Any:
111+
"""
112+
Asynchronous version of `peewee.SelectBase.scalar
113+
<https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.scalar>`_
114+
"""
115+
if as_dict:
116+
return await self.dicts().aio_peek(database)
117+
row = await self.tuples().aio_peek(database)
105118

106-
async def aio_get(self, database: Optional[AioDatabase] = None) -> Any:
119+
return row[0] if row and not as_tuple else row
120+
121+
@peewee.database_required
122+
async def aio_first(self, database: AioDatabase, n: int = 1) -> Any:
123+
"""
124+
Asynchronous version of `peewee.SelectBase.first
125+
<https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.first>`_
107126
"""
108-
Async version of **peewee.SelectBase.get**
109127

110-
See also:
111-
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.get
128+
if self._limit != n: # type: ignore
129+
self._limit = n
130+
return await self.aio_peek(database, n=n)
131+
132+
async def aio_get(self, database: Optional[AioDatabase] = None) -> Any:
112133
"""
113-
clone = self.paginate(1, 1) # type: ignore
134+
Asynchronous version of `peewee.SelectBase.get
135+
<https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.get>`_
136+
"""
137+
clone = self.paginate(1, 1)
114138
try:
115139
return (await clone.aio_execute(database))[0]
116140
except IndexError:
117141
sql, params = clone.sql()
118-
raise self.model.DoesNotExist('%s instance matching query does ' # type: ignore
142+
raise self.model.DoesNotExist('%s instance matching query does '
119143
'not exist:\nSQL: %s\nParams: %s' %
120144
(clone.model, sql, params))
121145

122146
@peewee.database_required
123147
async def aio_count(self, database: AioDatabase, clear_limit: bool = False) -> int:
124148
"""
125-
Async version of **peewee.SelectBase.count**
126-
127-
See also:
128-
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.count
149+
Asynchronous version of `peewee.SelectBase.count
150+
<https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.count>`_
129151
"""
130-
clone = self.order_by().alias('_wrapped') # type: ignore
152+
clone = self.order_by().alias('_wrapped')
131153
if clear_limit:
132154
clone._limit = clone._offset = None
133155
try:
@@ -145,38 +167,34 @@ async def aio_count(self, database: AioDatabase, clear_limit: bool = False) -> i
145167
@peewee.database_required
146168
async def aio_exists(self, database: AioDatabase) -> bool:
147169
"""
148-
Async version of **peewee.SelectBase.exists**
149-
150-
See also:
151-
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.exists
170+
Asynchronous version of `peewee.SelectBase.exists
171+
<https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.exists>`_
152172
"""
153-
clone = self.columns(peewee.SQL('1')) # type: ignore
173+
clone = self.columns(peewee.SQL('1'))
154174
clone._limit = 1
155175
clone._offset = None
156176
return bool(await clone.aio_scalar())
157177

158178
def union_all(self, rhs: Any) -> "AioModelCompoundSelectQuery":
159-
return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) # type: ignore
179+
return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs)
160180
__add__ = union_all
161181

162182
def union(self, rhs: Any) -> "AioModelCompoundSelectQuery":
163-
return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs) # type: ignore
183+
return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs)
164184
__or__ = union
165185

166186
def intersect(self, rhs: Any) -> "AioModelCompoundSelectQuery":
167-
return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) # type: ignore
187+
return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs)
168188
__and__ = intersect
169189

170190
def except_(self, rhs: Any) -> "AioModelCompoundSelectQuery":
171-
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) # type: ignore
191+
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs)
172192
__sub__ = except_
173193

174194
def aio_prefetch(self, *subqueries: Any, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> Any:
175195
"""
176-
Async version of **peewee.ModelSelect.prefetch**
177-
178-
See also:
179-
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#ModelSelect.prefetch
196+
Asynchronous version of `peewee.ModelSelect.prefetch
197+
<https://docs.peewee-orm.com/en/latest/peewee/api.html#ModelSelect.prefetch>`_
180198
"""
181199
return aio_prefetch(self, *subqueries, prefetch_type=prefetch_type)
182200

@@ -186,7 +204,7 @@ class AioSelect(AioSelectMixin, peewee.Select):
186204

187205

188206
class AioModelSelect(AioSelectMixin, peewee.ModelSelect):
189-
"""Async version of **peewee.ModelSelect** that provides async versions of ModelSelect methods
207+
"""Asynchronous version of **peewee.ModelSelect** that provides async versions of ModelSelect methods
190208
"""
191209
pass
192210

peewee_async/databases.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,21 @@ class AioDatabase(peewee.Database):
1818
connection and **async connections pool** interface.
1919
2020
:param pool_params: parameters that are passed to the pool
21-
:param min_connections: min connections pool size. Alias for pool_params.minsize
22-
:param max_connections: max connections pool size. Alias for pool_params.maxsize
2321
2422
Example::
2523
2624
database = PooledPostgresqlExtDatabase(
27-
'test',
28-
'min_connections': 1,
29-
'max_connections': 5,
30-
'pool_params': {"timeout": 30, 'pool_recycle': 1.5}
25+
'database': 'postgres',
26+
'host': '127.0.0.1',
27+
'port':5432,
28+
'password': 'postgres',
29+
'user': 'postgres',
30+
'pool_params': {
31+
"minsize": 0,
32+
"maxsize": 5,
33+
"timeout": 30,
34+
'pool_recycle': 1.5
35+
}
3136
)
3237
3338
See also:
@@ -189,8 +194,23 @@ class PsycopgDatabase(AioDatabase, Psycopg3Database):
189194
"""Extension for `peewee.PostgresqlDatabase` providing extra methods
190195
for managing async connection based on psycopg3 pool backend.
191196
197+
Example::
198+
199+
database = PsycopgDatabase(
200+
'database': 'postgres',
201+
'host': '127.0.0.1',
202+
'port': 5432,
203+
'password': 'postgres',
204+
'user': 'postgres',
205+
'pool_params': {
206+
"min_size": 0,
207+
"max_size": 5,
208+
'max_lifetime': 15
209+
}
210+
)
211+
192212
See also:
193-
https://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
213+
https://www.psycopg.org/psycopg3/docs/advanced/pool.html
194214
"""
195215

196216
pool_backend_cls = PsycopgPoolBackend
@@ -205,6 +225,23 @@ class PooledPostgresqlDatabase(AioDatabase, peewee.PostgresqlDatabase):
205225
"""Extension for `peewee.PostgresqlDatabase` providing extra methods
206226
for managing async connection based on aiopg pool backend.
207227
228+
229+
Example::
230+
231+
database = PooledPostgresqlExtDatabase(
232+
'database': 'postgres',
233+
'host': '127.0.0.1',
234+
'port':5432,
235+
'password': 'postgres',
236+
'user': 'postgres',
237+
'pool_params': {
238+
"minsize": 0,
239+
"maxsize": 5,
240+
"timeout": 30,
241+
'pool_recycle': 1.5
242+
}
243+
)
244+
208245
See also:
209246
https://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
210247
"""
@@ -230,11 +267,6 @@ class PooledPostgresqlExtDatabase(
230267
JSON fields support is enabled by default, HStore supports is disabled by
231268
default, but can be enabled through pool_params or with ``register_hstore=False`` argument.
232269
233-
Example::
234-
235-
database = PooledPostgresqlExtDatabase('test', register_hstore=False,
236-
max_connections=20)
237-
238270
See also:
239271
https://peewee.readthedocs.io/en/latest/peewee/playhouse.html#PostgresqlExtDatabase
240272
"""
@@ -251,7 +283,19 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
251283
252284
Example::
253285
254-
database = PooledMySQLDatabase('test', max_connections=10)
286+
database = PooledMySQLDatabase(
287+
'database': 'mysql',
288+
'host': '127.0.0.1',
289+
'port': 3306,
290+
'user': 'root',
291+
'password': 'mysql',
292+
'connect_timeout': 30,
293+
"pool_params": {
294+
"minsize": 0,
295+
"maxsize": 5,
296+
"pool_recycle": 2
297+
}
298+
)
255299
256300
See also:
257301
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase

peewee_async/result_wrappers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ def close(self) -> None:
2323
pass
2424

2525

26-
async def fetch_models(cursor: CursorProtocol, query: BaseQuery) -> List[Any]:
27-
rows = await cursor.fetchall()
26+
async def fetch_models(cursor: CursorProtocol, query: BaseQuery, count: Optional[int] = None) -> List[Any]:
27+
if count is None:
28+
rows = await cursor.fetchall()
29+
else:
30+
rows = await cursor.fetchmany(count)
2831
sync_cursor = SyncCursorAdapter(rows, cursor.description)
2932
_result_wrapper = query._get_cursor_wrapper(sync_cursor)
3033
return list(_result_wrapper)

peewee_async/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ async def fetchone(self) -> Any:
3333
async def fetchall(self) -> List[Any]:
3434
...
3535

36+
async def fetchmany(self, size: int) -> List[Any]:
37+
...
38+
3639
@property
3740
def lastrowid(self) -> int:
3841
...

tests/aio_model/test_shortcuts.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import List, Union
12
import uuid
23

34
import peewee
@@ -35,6 +36,64 @@ async def test_aio_get_or_none(db: AioDatabase) -> None:
3536
assert result is None
3637

3738

39+
@dbs_all
40+
@pytest.mark.parametrize(
41+
["peek_num", "expected"],
42+
(
43+
(1, 1),
44+
(2, [1,2]),
45+
(5, [1,2,3]),
46+
)
47+
)
48+
async def test_aio_peek(
49+
db: AioDatabase,
50+
peek_num: int,
51+
expected: Union[int, List[int]]
52+
) -> None:
53+
await IntegerTestModel.aio_create(num=1)
54+
await IntegerTestModel.aio_create(num=2)
55+
await IntegerTestModel.aio_create(num=3)
56+
57+
rows = await IntegerTestModel.select().order_by(
58+
IntegerTestModel.num
59+
).aio_peek(n=peek_num)
60+
61+
if isinstance(rows, list):
62+
result = [r.num for r in rows]
63+
else:
64+
result = rows.num
65+
assert result == expected
66+
67+
68+
@dbs_all
69+
@pytest.mark.parametrize(
70+
["first_num", "expected"],
71+
(
72+
(1, 1),
73+
(2, [1,2]),
74+
(5, [1,2,3]),
75+
)
76+
)
77+
async def test_aio_first(
78+
db: AioDatabase,
79+
first_num: int,
80+
expected: Union[int, List[int]]
81+
) -> None:
82+
await IntegerTestModel.aio_create(num=1)
83+
await IntegerTestModel.aio_create(num=2)
84+
await IntegerTestModel.aio_create(num=3)
85+
86+
rows = await IntegerTestModel.select().order_by(
87+
IntegerTestModel.num
88+
).aio_first(n=first_num)
89+
90+
if isinstance(rows, list):
91+
result = [r.num for r in rows]
92+
else:
93+
result = rows.num
94+
assert result == expected
95+
96+
3897
@dbs_all
3998
async def test_aio_scalar(db: AioDatabase) -> None:
4099
await IntegerTestModel.aio_create(num=1)
@@ -46,6 +105,11 @@ async def test_aio_scalar(db: AioDatabase) -> None:
46105
fn.MAX(IntegerTestModel.num),fn.Min(IntegerTestModel.num)
47106
).aio_scalar(as_tuple=True) == (2, 1)
48107

108+
assert await IntegerTestModel.select(
109+
fn.MAX(IntegerTestModel.num).alias('max'),
110+
fn.Min(IntegerTestModel.num).alias('min')
111+
).aio_scalar(as_dict=True) == {'max': 2, 'min': 1}
112+
49113
assert await TestModel.select().aio_scalar() is None
50114

51115

0 commit comments

Comments
 (0)