1
- import peewee
2
- from fastapi import FastAPI
3
1
import logging
4
- import uvicorn
5
2
import random
3
+ from contextlib import asynccontextmanager
4
+
5
+ import peewee
6
+ import uvicorn
7
+ from aiopg .connection import Connection
8
+ from aiopg .pool import Pool
9
+ from fastapi import FastAPI
10
+
11
+ acquire = Pool .acquire
12
+ cursor = Connection .cursor
13
+
14
+
15
+ def new_acquire (self ):
16
+ choice = random .randint (1 , 5 )
17
+ if choice == 5 :
18
+ raise Exception ("some network error" ) # network error imitation
19
+ return acquire (self )
20
+
21
+
22
+ def new_cursor (self ):
23
+ choice = random .randint (1 , 5 )
24
+ if choice == 5 :
25
+ raise Exception ("some network error" ) # network error imitation
26
+ return cursor (self )
27
+
28
+ Connection .cursor = new_cursor
29
+ Pool .acquire = new_acquire
6
30
7
31
import peewee_async
8
- from contextlib import asynccontextmanager
9
- import functools
32
+
10
33
11
34
logging .basicConfig ()
12
- pg_db = peewee_async .PooledPostgresqlDatabase (None )
35
+ pg_db = peewee_async .PooledPostgresqlDatabase (
36
+ database = 'postgres' ,
37
+ user = 'postgres' ,
38
+ password = 'postgres' ,
39
+ host = 'localhost' ,
40
+ port = 5432 ,
41
+ max_connections = 3
42
+ )
13
43
14
44
15
45
class Manager (peewee_async .Manager ):
16
46
"""Async models manager."""
17
47
18
- database = peewee_async .PooledPostgresqlDatabase (
19
- database = 'postgres' ,
20
- user = 'postgres' ,
21
- password = 'postgres' ,
22
- host = 'localhost' ,
23
- port = 5432 ,
24
- )
48
+ database = pg_db
25
49
26
50
27
51
manager = Manager ()
28
52
29
53
30
- def patch_manager (manager ):
31
- async def cursor (self , conn = None , * args , ** kwargs ):
32
-
33
- choice = random .randint (1 , 5 )
34
- if choice == 5 :
35
- raise Exception ("some network error" ) # network error imitation
36
-
37
- # actual code
38
- in_transaction = conn is not None
39
- if not conn :
40
- conn = await self .acquire ()
41
- cursor = await conn .cursor (* args , ** kwargs )
42
- cursor .release = functools .partial (
43
- self .release_cursor , cursor ,
44
- in_transaction = in_transaction )
45
- return cursor
46
-
47
- manager .database ._async_conn_cls .cursor = cursor
48
-
49
54
def setup_logging ():
50
55
logger = logging .getLogger ("uvicorn.error" )
51
56
handler = logging .FileHandler (filename = "app.log" , mode = "w" )
@@ -70,7 +75,6 @@ async def lifespan(app: FastAPI):
70
75
operation = 'TRUNCATE TABLE MySimplestModel;' ,
71
76
)
72
77
setup_logging ()
73
- patch_manager (manager )
74
78
yield
75
79
# Clean up the ML models and release the resources
76
80
await manager .close ()
@@ -90,6 +94,46 @@ async def test():
90
94
return errors
91
95
92
96
97
+ async def nested_transaction ():
98
+ async with manager .transaction ():
99
+ await manager .execute (MySimplestModel .update (id = 1 ))
100
+
101
+
102
+ async def nested_atomic ():
103
+ async with manager .atomic ():
104
+ await manager .execute (MySimplestModel .update (id = 1 ))
105
+
106
+
107
+ @app .get ("/transaction" )
108
+ async def test ():
109
+ try :
110
+ async with manager .transaction ():
111
+ await manager .execute (MySimplestModel .update (id = 1 ))
112
+ await nested_transaction ()
113
+ except Exception as e :
114
+ errors .add (str (e ))
115
+ raise
116
+ return errors
117
+
118
+
119
+ @app .get ("/atomic" )
120
+ async def test ():
121
+ try :
122
+ async with manager .atomic ():
123
+ await manager .execute (MySimplestModel .update (id = 1 ))
124
+ await nested_atomic ()
125
+ except Exception as e :
126
+ errors .add (str (e ))
127
+ raise
128
+ return errors
129
+
130
+
131
+ @app .get ("/recreate_pool" )
132
+ async def test ():
133
+ await manager .database .close_async ()
134
+ await manager .database .connect_async ()
135
+
136
+
93
137
if __name__ == "__main__" :
94
138
uvicorn .run (
95
139
app ,
0 commit comments