Skip to content

Commit 3089e5f

Browse files
committed
feat: aio_count method added
1 parent bc6bc52 commit 3089e5f

File tree

5 files changed

+122
-86
lines changed

5 files changed

+122
-86
lines changed

peewee_async.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,8 @@ async def aio_execute(self, query, fetch_results=None):
397397
# To make `Database.aio_execute` compatible with peewee's sync queries we
398398
# apply optional patching, it will do nothing for Aio-counterparts:
399399
_patch_query_with_compat_methods(query, None)
400-
sql, params = query.sql()
400+
ctx = self.get_sql_context()
401+
sql, params = ctx.sql(query).query()
401402
fetch_results = fetch_results or getattr(query, 'fetch_results', None)
402403
return await self.aio_execute_sql(sql, params, fetch_results=fetch_results)
403404

@@ -694,7 +695,7 @@ async def fetch_results(self, cursor):
694695
return await self.make_async_query_wrapper(cursor)
695696

696697

697-
class AioModelSelect(peewee.ModelSelect, AioQueryMixin):
698+
class AioSelectMixin(AioQueryMixin):
698699

699700
async def fetch_results(self, cursor):
700701
return await self.make_async_query_wrapper(cursor)
@@ -723,6 +724,28 @@ async def aio_get(self, database=None):
723724
'not exist:\nSQL: %s\nParams: %s' %
724725
(clone.model, sql, params))
725726

727+
@peewee.database_required
728+
async def aio_count(self, database, clear_limit=False):
729+
clone = self.order_by().alias('_wrapped')
730+
if clear_limit:
731+
clone._limit = clone._offset = None
732+
try:
733+
if clone._having is None and clone._group_by is None and \
734+
clone._windows is None and clone._distinct is None and \
735+
clone._simple_distinct is not True:
736+
clone = clone.select(peewee.SQL('1'))
737+
except AttributeError:
738+
pass
739+
return await AioSelect([clone], [peewee.fn.COUNT(peewee.SQL('1'))]).aio_scalar(database)
740+
741+
742+
class AioSelect(peewee.Select, AioSelectMixin):
743+
pass
744+
745+
746+
class AioModelSelect(peewee.ModelSelect, AioSelectMixin):
747+
pass
748+
726749

727750
class AioModel(peewee.Model):
728751
"""Async version of **peewee.Model** that allows to execute queries asynchronously

peewee_async_compat.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def _patch_query_with_compat_methods(query, async_query_cls):
7878
if async_query_cls is AioModelSelect:
7979
query.aio_get = partial(async_query_cls.aio_get, query)
8080
query.aio_scalar = partial(async_query_cls.aio_scalar, query)
81+
query.aio_count = partial(async_query_cls.aio_count, query)
8182

8283

8384
def _query_db(query):
@@ -94,25 +95,13 @@ async def count(query, clear_limit=False):
9495
9596
:return: number of objects in `select()` query
9697
"""
97-
database = _query_db(query)
98-
clone = query.clone()
99-
if query._distinct or query._group_by or query._limit or query._offset:
100-
if clear_limit:
101-
clone._limit = clone._offset = None
102-
sql, params = clone.sql()
103-
wrapped = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql
104-
async def fetch_results(cursor):
105-
row = await cursor.fetchone()
106-
if row:
107-
return row[0]
108-
else:
109-
return row
110-
result = await database.aio_execute_sql(wrapped, params, fetch_results)
111-
return result or 0
112-
else:
113-
clone._returning = [peewee.fn.Count(peewee.SQL('*'))]
114-
clone._order_by = None
115-
return (await scalar(clone)) or 0
98+
from peewee_async import AioModelSelect # noqa
99+
warnings.warn(
100+
"`count` is deprecated, use `query.aio_count` method.",
101+
DeprecationWarning
102+
)
103+
_patch_query_with_compat_methods(query, AioModelSelect)
104+
return await query.aio_count(clear_limit=clear_limit)
116105

117106

118107
async def prefetch(sq, *subqueries, prefetch_type):

tests/aio_model/test_shortcuts.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,21 @@ async def test_aio_scalar(db):
4343
).aio_scalar(as_tuple=True) == (2, 1)
4444

4545
assert await TestModel.select().aio_scalar() is None
46+
47+
48+
@dbs_all
49+
async def test_count_query(db):
50+
51+
for num in range(5):
52+
await IntegerTestModel.aio_create(num=num)
53+
count = await IntegerTestModel.select().limit(3).aio_count()
54+
assert count == 3
55+
56+
57+
@dbs_all
58+
async def test_count_query_clear_limit(db):
59+
60+
for num in range(5):
61+
await IntegerTestModel.aio_create(num=num)
62+
count = await IntegerTestModel.select().limit(3).aio_count(clear_limit=True)
63+
assert count == 5

tests/compat/test_shortcuts.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import uuid
2+
3+
import peewee
4+
5+
from tests.conftest import manager_for_all_dbs
6+
from tests.models import CompatTestModel
7+
8+
9+
@manager_for_all_dbs
10+
async def test_get_or_none(manager):
11+
"""Test get_or_none manager function."""
12+
text1 = "Test %s" % uuid.uuid4()
13+
text2 = "Test %s" % uuid.uuid4()
14+
15+
obj1 = await manager.create(CompatTestModel, text=text1)
16+
obj2 = await manager.get_or_none(CompatTestModel, text=text1)
17+
obj3 = await manager.get_or_none(CompatTestModel, text=text2)
18+
19+
assert obj1 == obj2
20+
assert obj1 is not None
21+
assert obj2 is not None
22+
assert obj3 is None
23+
24+
25+
@manager_for_all_dbs
26+
async def test_count_query_with_limit(manager):
27+
text = "Test %s" % uuid.uuid4()
28+
await manager.create(CompatTestModel, text=text)
29+
text = "Test %s" % uuid.uuid4()
30+
await manager.create(CompatTestModel, text=text)
31+
text = "Test %s" % uuid.uuid4()
32+
await manager.create(CompatTestModel, text=text)
33+
34+
count = await manager.count(CompatTestModel.select().limit(1))
35+
assert count == 1
36+
37+
38+
@manager_for_all_dbs
39+
async def test_count_query(manager):
40+
text = "Test %s" % uuid.uuid4()
41+
await manager.create(CompatTestModel, text=text)
42+
text = "Test %s" % uuid.uuid4()
43+
await manager.create(CompatTestModel, text=text)
44+
text = "Test %s" % uuid.uuid4()
45+
await manager.create(CompatTestModel, text=text)
46+
47+
count = await manager.count(CompatTestModel.select())
48+
assert count == 3
49+
50+
51+
@manager_for_all_dbs
52+
async def test_scalar_query(manager):
53+
54+
text = "Test %s" % uuid.uuid4()
55+
await manager.create(CompatTestModel, text=text)
56+
text = "Test %s" % uuid.uuid4()
57+
await manager.create(CompatTestModel, text=text)
58+
59+
fn = peewee.fn.Count(CompatTestModel.id)
60+
count = await manager.scalar(CompatTestModel.select(fn))
61+
62+
assert count == 2
63+
64+
65+
@manager_for_all_dbs
66+
async def test_create_obj(manager):
67+
68+
text = "Test %s" % uuid.uuid4()
69+
obj = await manager.create(CompatTestModel, text=text)
70+
assert obj is not None
71+
assert obj.text == text

tests/test_shortcuts.py

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -43,62 +43,6 @@ async def test_prefetch(manager, prefetch_type):
4343
assert tuple(result[0].betas[0].gammas) == (gamma_111, gamma_112)
4444

4545

46-
@manager_for_all_dbs
47-
async def test_get_or_none(manager):
48-
"""Test get_or_none manager function."""
49-
text1 = "Test %s" % uuid.uuid4()
50-
text2 = "Test %s" % uuid.uuid4()
51-
52-
obj1 = await manager.create(TestModel, text=text1)
53-
obj2 = await manager.get_or_none(TestModel, text=text1)
54-
obj3 = await manager.get_or_none(TestModel, text=text2)
55-
56-
assert obj1 == obj2
57-
assert obj1 is not None
58-
assert obj2 is not None
59-
assert obj3 is None
60-
61-
62-
@manager_for_all_dbs
63-
async def test_count_query_with_limit(manager):
64-
text = "Test %s" % uuid.uuid4()
65-
await manager.create(TestModel, text=text)
66-
text = "Test %s" % uuid.uuid4()
67-
await manager.create(TestModel, text=text)
68-
text = "Test %s" % uuid.uuid4()
69-
await manager.create(TestModel, text=text)
70-
71-
count = await manager.count(TestModel.select().limit(1))
72-
assert count == 1
73-
74-
75-
@manager_for_all_dbs
76-
async def test_count_query(manager):
77-
text = "Test %s" % uuid.uuid4()
78-
await manager.create(TestModel, text=text)
79-
text = "Test %s" % uuid.uuid4()
80-
await manager.create(TestModel, text=text)
81-
text = "Test %s" % uuid.uuid4()
82-
await manager.create(TestModel, text=text)
83-
84-
count = await manager.count(TestModel.select())
85-
assert count == 3
86-
87-
88-
@manager_for_all_dbs
89-
async def test_scalar_query(manager):
90-
91-
text = "Test %s" % uuid.uuid4()
92-
await manager.create(TestModel, text=text)
93-
text = "Test %s" % uuid.uuid4()
94-
await manager.create(TestModel, text=text)
95-
96-
fn = peewee.fn.Count(TestModel.id)
97-
count = await manager.scalar(TestModel.select(fn))
98-
99-
assert count == 2
100-
101-
10246
@manager_for_all_dbs
10347
async def test_delete_obj(manager):
10448
text = "Test %s" % uuid.uuid4()
@@ -124,15 +68,6 @@ async def test_update_obj(manager):
12468
assert obj2.text == "Test update object"
12569

12670

127-
@manager_for_all_dbs
128-
async def test_create_obj(manager):
129-
130-
text = "Test %s" % uuid.uuid4()
131-
obj = await manager.create(TestModel, text=text)
132-
assert obj is not None
133-
assert obj.text == text
134-
135-
13671
@manager_for_all_dbs
13772
async def test_create_or_get(manager):
13873
text = "Test %s" % uuid.uuid4()

0 commit comments

Comments
 (0)