Skip to content

Commit c93271e

Browse files
authored
chore: Refactor Async(Postgresql/MySQL)Connection (#213)
1 parent 9e3dda8 commit c93271e

File tree

1 file changed

+50
-69
lines changed

1 file changed

+50
-69
lines changed

peewee_async.py

Lines changed: 50 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Copyright (c) 2014, Alexey Kinëv <[email protected]>
1414
1515
"""
16+
import abc
1617
import asyncio
1718
import contextlib
1819
import functools
@@ -712,7 +713,7 @@ async def connect_async(self, loop=None, timeout=None):
712713
timeout=timeout,
713714
**self.connect_params_async
714715
)
715-
await conn.connect()
716+
await conn.create()
716717
self._async_conn = conn
717718

718719
async def cursor_async(self):
@@ -734,7 +735,7 @@ async def close_async(self):
734735
if self._async_conn:
735736
conn = self._async_conn
736737
self._async_conn = None
737-
await conn.close()
738+
await conn.terminate()
738739

739740
async def push_transaction_async(self):
740741
"""Increment async transaction depth.
@@ -851,19 +852,14 @@ async def aio_execute(self, query):
851852
return (await coroutine(query))
852853

853854

854-
##############
855-
# PostgreSQL #
856-
##############
857-
858-
859-
class AsyncPostgresqlConnection:
855+
class AioPool(metaclass=abc.ABCMeta):
860856
"""Asynchronous database connection pool.
861857
"""
862858
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
863859
self.pool = None
864860
self.loop = loop
865861
self.database = database
866-
self.timeout = timeout or aiopg.DEFAULT_TIMEOUT
862+
self.timeout = timeout
867863
self.connect_params = kwargs
868864

869865
async def acquire(self):
@@ -876,24 +872,20 @@ def release(self, conn):
876872
"""
877873
self.pool.release(conn)
878874

879-
async def connect(self):
875+
@abc.abstractmethod
876+
async def create(self):
880877
"""Create connection pool asynchronously.
881878
"""
882-
self.pool = await aiopg.create_pool(
883-
loop=self.loop,
884-
timeout=self.timeout,
885-
database=self.database,
886-
**self.connect_params)
879+
raise NotImplementedError
887880

888-
async def close(self):
881+
async def terminate(self):
889882
"""Terminate all pool connections.
890883
"""
891884
self.pool.terminate()
892885
await self.pool.wait_closed()
893886

894887
async def cursor(self, conn=None, *args, **kwargs):
895-
"""Get a cursor for the specified transaction connection
896-
or acquire from the pool.
888+
"""Get cursor for connection from pool.
897889
"""
898890
in_transaction = conn is not None
899891
if not conn:
@@ -914,10 +906,44 @@ async def release_cursor(self, cursor, in_transaction=False):
914906
the connection is also released back to the pool.
915907
"""
916908
conn = cursor.connection
917-
cursor.close()
909+
await self.close_cursor(cursor)
918910
if not in_transaction:
919911
self.release(conn)
920912

913+
@abc.abstractmethod
914+
async def close_cursor(self, cursor):
915+
raise NotImplementedError
916+
917+
918+
919+
##############
920+
# PostgreSQL #
921+
##############
922+
923+
924+
class AioPostgresqlPool(AioPool):
925+
"""Asynchronous database connection pool.
926+
"""
927+
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
928+
super().__init__(
929+
database=database,
930+
loop=loop,
931+
timeout=timeout or aiopg.DEFAULT_TIMEOUT,
932+
**kwargs,
933+
)
934+
935+
async def create(self):
936+
"""Create connection pool asynchronously.
937+
"""
938+
self.pool = await aiopg.create_pool(
939+
loop=self.loop,
940+
timeout=self.timeout,
941+
database=self.database,
942+
**self.connect_params)
943+
944+
async def close_cursor(self, cursor):
945+
cursor.close()
946+
921947

922948
class AsyncPostgresqlMixin(AsyncDatabase):
923949
"""Mixin for `peewee.PostgresqlDatabase` providing extra methods
@@ -926,7 +952,7 @@ class AsyncPostgresqlMixin(AsyncDatabase):
926952
if psycopg2:
927953
Error = psycopg2.Error
928954

929-
def init_async(self, conn_cls=AsyncPostgresqlConnection,
955+
def init_async(self, conn_cls=AioPostgresqlPool,
930956
enable_json=False, enable_hstore=False):
931957
if not aiopg:
932958
raise Exception("Error, aiopg is not installed!")
@@ -1027,27 +1053,11 @@ def use_speedups(self, value):
10271053
#########
10281054

10291055

1030-
class AsyncMySQLConnection:
1056+
class AioMysqlPool(AioPool):
10311057
"""Asynchronous database connection pool.
10321058
"""
1033-
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
1034-
self.pool = None
1035-
self.loop = loop
1036-
self.database = database
1037-
self.timeout = timeout
1038-
self.connect_params = kwargs
1039-
1040-
async def acquire(self):
1041-
"""Acquire connection from pool.
1042-
"""
1043-
return (await self.pool.acquire())
1044-
1045-
def release(self, conn):
1046-
"""Release connection to pool.
1047-
"""
1048-
self.pool.release(conn)
10491059

1050-
async def connect(self):
1060+
async def create(self):
10511061
"""Create connection pool asynchronously.
10521062
"""
10531063
self.pool = await aiomysql.create_pool(
@@ -1056,37 +1066,8 @@ async def connect(self):
10561066
connect_timeout=self.timeout,
10571067
**self.connect_params)
10581068

1059-
async def close(self):
1060-
"""Terminate all pool connections.
1061-
"""
1062-
self.pool.terminate()
1063-
await self.pool.wait_closed()
1064-
1065-
async def cursor(self, conn=None, *args, **kwargs):
1066-
"""Get cursor for connection from pool.
1067-
"""
1068-
in_transaction = conn is not None
1069-
if not conn:
1070-
conn = await self.acquire()
1071-
try:
1072-
cursor = await conn.cursor(*args, **kwargs)
1073-
except:
1074-
if not in_transaction:
1075-
self.release(conn)
1076-
raise
1077-
cursor.release = functools.partial(
1078-
self.release_cursor, cursor,
1079-
in_transaction=in_transaction)
1080-
return cursor
1081-
1082-
async def release_cursor(self, cursor, in_transaction=False):
1083-
"""Release cursor coroutine. Unless in transaction,
1084-
the connection is also released back to the pool.
1085-
"""
1086-
conn = cursor.connection
1069+
async def close_cursor(self, cursor):
10871070
await cursor.close()
1088-
if not in_transaction:
1089-
self.release(conn)
10901071

10911072

10921073
class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase):
@@ -1108,7 +1089,7 @@ def init(self, database, **kwargs):
11081089
raise Exception("Error, aiomysql is not installed!")
11091090
self.min_connections = 1
11101091
self.max_connections = 1
1111-
self._async_conn_cls = kwargs.pop('async_conn', AsyncMySQLConnection)
1092+
self._async_conn_cls = kwargs.pop('async_conn', AioMysqlPool)
11121093
super().init(database, **kwargs)
11131094

11141095
@property

0 commit comments

Comments
 (0)