Skip to content

Commit dfc09f9

Browse files
committed
Added transactions to DataSync
This adds a new high-level `withTransaction` function and a lower-level `Transaction` object to deal with postgres transactions from DataSync. The following operations can be used from within a transaction: - createRecord, createRecords - updateRecord, updateRecords - deleteRecord, deleteRecords - query
1 parent 87dbe2c commit dfc09f9

File tree

6 files changed

+285
-41
lines changed

6 files changed

+285
-41
lines changed

IHP/DataSync/Controller.hs

Lines changed: 130 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import qualified IHP.PGListener as PGListener
2525
import IHP.ApplicationContext
2626
import Data.Set (Set)
2727
import qualified Data.Set as Set
28+
import qualified Data.Pool as Pool
2829

2930
instance (
3031
PG.ToField (PrimaryKey (GetTableName CurrentUserRecord))
@@ -36,7 +37,7 @@ instance (
3637
initialState = DataSyncController
3738

3839
run = do
39-
setState DataSyncReady { subscriptions = HashMap.empty }
40+
setState DataSyncReady { subscriptions = HashMap.empty, transactions = HashMap.empty }
4041

4142
ensureRLSEnabled <- makeCachedEnsureRLSEnabled
4243
installTableChangeTriggers <- ChangeNotifications.makeCachedInstallTableChangeTriggers
@@ -45,12 +46,12 @@ instance (
4546

4647
let
4748
handleMessage :: DataSyncMessage -> IO ()
48-
handleMessage DataSyncQuery { query, requestId } = do
49+
handleMessage DataSyncQuery { query, requestId, transactionId } = do
4950
ensureRLSEnabled (get #table query)
5051

5152
let (theQuery, theParams) = compileQuery query
5253

53-
result :: [[Field]] <- sqlQueryWithRLS theQuery theParams
54+
result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId theQuery theParams
5455

5556
sendJSON DataSyncResult { result, requestId }
5657

@@ -131,7 +132,7 @@ instance (
131132
132133
sendJSON DidDeleteDataSubscription { subscriptionId, requestId }
133134
134-
handleMessage CreateRecordMessage { table, record, requestId } = do
135+
handleMessage CreateRecordMessage { table, record, requestId, transactionId } = do
135136
ensureRLSEnabled table
136137
137138
let query = "INSERT INTO ? ? VALUES ? RETURNING *"
@@ -145,15 +146,15 @@ instance (
145146
146147
let params = (PG.Identifier table, PG.In (map PG.Identifier columns), PG.In values)
147148
148-
result :: [[Field]] <- sqlQueryWithRLS query params
149+
result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params
149150
150151
case result of
151152
[record] -> sendJSON DidCreateRecord { requestId, record }
152153
otherwise -> error "Unexpected result in CreateRecordMessage handler"
153154
154155
pure ()
155156
156-
handleMessage CreateRecordsMessage { table, records, requestId } = do
157+
handleMessage CreateRecordsMessage { table, records, requestId, transactionId } = do
157158
ensureRLSEnabled table
158159
159160
let query = "INSERT INTO ? ? ? RETURNING *"
@@ -175,13 +176,13 @@ instance (
175176
176177
let params = (PG.Identifier table, PG.In (map PG.Identifier columns), PG.Values [] values)
177178
178-
records :: [[Field]] <- sqlQueryWithRLS query params
179+
records :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params
179180
180181
sendJSON DidCreateRecords { requestId, records }
181182
182183
pure ()
183184
184-
handleMessage UpdateRecordMessage { table, id, patch, requestId } = do
185+
handleMessage UpdateRecordMessage { table, id, patch, requestId, transactionId } = do
185186
ensureRLSEnabled table
186187
187188
let columns = patch
@@ -204,15 +205,15 @@ instance (
204205
<> (join (map (\(key, value) -> [PG.toField key, value]) keyValues))
205206
<> [PG.toField id]
206207
207-
result :: [[Field]] <- sqlQueryWithRLS (PG.Query query) params
208+
result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params
208209
209210
case result of
210211
[record] -> sendJSON DidUpdateRecord { requestId, record }
211212
otherwise -> error "Unexpected result in UpdateRecordMessage handler"
212213
213214
pure ()
214215
215-
handleMessage UpdateRecordsMessage { table, ids, patch, requestId } = do
216+
handleMessage UpdateRecordsMessage { table, ids, patch, requestId, transactionId } = do
216217
ensureRLSEnabled table
217218
218219
let columns = patch
@@ -235,26 +236,63 @@ instance (
235236
<> (join (map (\(key, value) -> [PG.toField key, value]) keyValues))
236237
<> [PG.toField (PG.In ids)]
237238
238-
records <- sqlQueryWithRLS (PG.Query query) params
239+
records <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params
239240
240241
sendJSON DidUpdateRecords { requestId, records }
241242
242243
pure ()
243244
244-
handleMessage DeleteRecordMessage { table, id, requestId } = do
245+
handleMessage DeleteRecordMessage { table, id, requestId, transactionId } = do
245246
ensureRLSEnabled table
246247
247-
sqlExecWithRLS "DELETE FROM ? WHERE id = ?" (PG.Identifier table, id)
248+
sqlExecWithRLSAndTransactionId transactionId "DELETE FROM ? WHERE id = ?" (PG.Identifier table, id)
248249
249250
sendJSON DidDeleteRecord { requestId }
250251
251-
handleMessage DeleteRecordsMessage { table, ids, requestId } = do
252+
handleMessage DeleteRecordsMessage { table, ids, requestId, transactionId } = do
252253
ensureRLSEnabled table
253254
254-
sqlExecWithRLS "DELETE FROM ? WHERE id IN ?" (PG.Identifier table, PG.In ids)
255+
sqlExecWithRLSAndTransactionId transactionId "DELETE FROM ? WHERE id IN ?" (PG.Identifier table, PG.In ids)
255256
256257
sendJSON DidDeleteRecords { requestId }
257258
259+
handleMessage StartTransaction { requestId } = do
260+
ensureBelowTransactionLimit
261+
262+
transactionId <- UUID.nextRandom
263+
264+
(connection, localPool) <- ?modelContext
265+
|> get #connectionPool
266+
|> Pool.takeResource
267+
268+
let transaction = DataSyncTransaction
269+
{ id = transactionId
270+
, connection
271+
, releaseConnection = Pool.putResource localPool connection
272+
}
273+
274+
let globalModelContext = ?modelContext
275+
let ?modelContext = globalModelContext { transactionConnection = Just connection } in sqlExecWithRLS "BEGIN" ()
276+
277+
modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.insert transactionId transaction))
278+
279+
sendJSON DidStartTransaction { requestId, transactionId }
280+
281+
handleMessage RollbackTransaction { requestId, id } = do
282+
sqlExecWithRLSAndTransactionId (Just id) "ROLLBACK" ()
283+
284+
closeTransaction id
285+
286+
sendJSON DidRollbackTransaction { requestId, transactionId = id }
287+
288+
handleMessage CommitTransaction { requestId, id } = do
289+
sqlExecWithRLSAndTransactionId (Just id) "COMMIT" ()
290+
291+
closeTransaction id
292+
293+
sendJSON DidCommitTransaction { requestId, transactionId = id }
294+
295+
258296
259297
forever do
260298
message <- Aeson.eitherDecodeStrict' <$> receiveData @ByteString
@@ -289,13 +327,15 @@ cleanupAllSubscriptions = do
289327
let pgListener = ?applicationContext |> get #pgListener
290328
291329
case state of
292-
DataSyncReady { subscriptions } -> do
330+
DataSyncReady { subscriptions, transactions } -> do
293331
let channelSubscriptions = subscriptions
294332
|> HashMap.elems
295333
|> map (get #channelSubscription)
296334
forEach channelSubscriptions \channelSubscription -> do
297335
pgListener |> PGListener.unsubscribe channelSubscription
298336
337+
forEach (HashMap.elems transactions) (get #releaseConnection)
338+
299339
pure ()
300340
_ -> pure ()
301341
@@ -310,8 +350,81 @@ queryFieldNamesToColumnNames sqlQuery = sqlQuery
310350
where
311351
convertOrderByClause OrderByClause { orderByColumn, orderByDirection } = OrderByClause { orderByColumn = cs (fieldNameToColumnName (cs orderByColumn)), orderByDirection }
312352
353+
354+
runInModelContextWithTransaction :: (?state :: IORef DataSyncController, _) => ((?modelContext :: ModelContext) => IO result) -> Maybe UUID -> IO result
355+
runInModelContextWithTransaction function (Just transactionId) = do
356+
let globalModelContext = ?modelContext
357+
358+
DataSyncTransaction { connection } <- findTransactionById transactionId
359+
let
360+
?modelContext = globalModelContext { transactionConnection = Just connection }
361+
in
362+
function
363+
runInModelContextWithTransaction function Nothing = function
364+
365+
findTransactionById :: (?state :: IORef DataSyncController) => UUID -> IO DataSyncTransaction
366+
findTransactionById transactionId = do
367+
transactions <- get #transactions <$> readIORef ?state
368+
case HashMap.lookup transactionId transactions of
369+
Just transaction -> pure transaction
370+
Nothing -> error "No transaction with that id"
371+
372+
closeTransaction transactionId = do
373+
DataSyncTransaction { releaseConnection } <- findTransactionById transactionId
374+
modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.delete transactionId))
375+
releaseConnection
376+
377+
-- | Allow max 10 concurrent transactions per connection to avoid running out of database connections
378+
--
379+
-- Each transaction removes a database connection from the connection pool. If we don't limit the transactions,
380+
-- a single user could take down the application by starting more than 'IHP.FrameworkConfig.DBPoolMaxConnections'
381+
-- concurrent transactions. Then all database connections are removed from the connection pool and further database
382+
-- queries for other users will fail.
383+
--
384+
ensureBelowTransactionLimit :: (?state :: IORef DataSyncController) => IO ()
385+
ensureBelowTransactionLimit = do
386+
transactions <- get #transactions <$> readIORef ?state
387+
let transactionCount = HashMap.size transactions
388+
let maxTransactionsPerConnection = 10
389+
when (transactionCount >= maxTransactionsPerConnection) do
390+
error ("You've reached the transaction limit of " <> tshow maxTransactionsPerConnection <> " transactions")
391+
392+
sqlQueryWithRLSAndTransactionId ::
393+
( ?modelContext :: ModelContext
394+
, PG.ToRow parameters
395+
, ?context :: ControllerContext
396+
, userId ~ Id CurrentUserRecord
397+
, Show (PrimaryKey (GetTableName CurrentUserRecord))
398+
, HasNewSessionUrl CurrentUserRecord
399+
, Typeable CurrentUserRecord
400+
, ?context :: ControllerContext
401+
, HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
402+
, PG.ToField userId
403+
, FromRow result
404+
, ?state :: IORef DataSyncController
405+
) => Maybe UUID -> PG.Query -> parameters -> IO [result]
406+
sqlQueryWithRLSAndTransactionId transactionId theQuery theParams = runInModelContextWithTransaction (sqlQueryWithRLS theQuery theParams) transactionId
407+
408+
sqlExecWithRLSAndTransactionId ::
409+
( ?modelContext :: ModelContext
410+
, PG.ToRow parameters
411+
, ?context :: ControllerContext
412+
, userId ~ Id CurrentUserRecord
413+
, Show (PrimaryKey (GetTableName CurrentUserRecord))
414+
, HasNewSessionUrl CurrentUserRecord
415+
, Typeable CurrentUserRecord
416+
, ?context :: ControllerContext
417+
, HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
418+
, PG.ToField userId
419+
, ?state :: IORef DataSyncController
420+
) => Maybe UUID -> PG.Query -> parameters -> IO Int64
421+
sqlExecWithRLSAndTransactionId transactionId theQuery theParams = runInModelContextWithTransaction (sqlExecWithRLS theQuery theParams) transactionId
422+
313423
$(deriveFromJSON defaultOptions 'DataSyncQuery)
314424
$(deriveToJSON defaultOptions 'DataSyncResult)
315425
316426
instance SetField "subscriptions" DataSyncController (HashMap UUID Subscription) where
317-
setField subscriptions record = record { subscriptions }
427+
setField subscriptions record = record { subscriptions }
428+
429+
instance SetField "transactions" DataSyncController (HashMap UUID DataSyncTransaction) where
430+
setField transactions record = record { transactions }

IHP/DataSync/Types.hs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,21 @@ import IHP.QueryBuilder
66
import IHP.DataSync.DynamicQuery
77
import Data.HashMap.Strict (HashMap)
88
import qualified IHP.PGListener as PGListener
9+
import qualified Database.PostgreSQL.Simple as PG
910

1011
data DataSyncMessage
11-
= DataSyncQuery { query :: !DynamicSQLQuery, requestId :: !Int }
12+
= DataSyncQuery { query :: !DynamicSQLQuery, requestId :: !Int, transactionId :: !(Maybe UUID) }
1213
| CreateDataSubscription { query :: !DynamicSQLQuery, requestId :: !Int }
1314
| DeleteDataSubscription { subscriptionId :: !UUID, requestId :: !Int }
14-
| CreateRecordMessage { table :: !Text, record :: !(HashMap Text Value), requestId :: !Int }
15-
| CreateRecordsMessage { table :: !Text, records :: ![HashMap Text Value], requestId :: !Int }
16-
| UpdateRecordMessage { table :: !Text, id :: !UUID, patch :: !(HashMap Text Value), requestId :: !Int }
17-
| UpdateRecordsMessage { table :: !Text, ids :: ![UUID], patch :: !(HashMap Text Value), requestId :: !Int }
18-
| DeleteRecordMessage { table :: !Text, id :: !UUID, requestId :: !Int }
19-
| DeleteRecordsMessage { table :: !Text, ids :: ![UUID], requestId :: !Int }
15+
| CreateRecordMessage { table :: !Text, record :: !(HashMap Text Value), requestId :: !Int, transactionId :: !(Maybe UUID) }
16+
| CreateRecordsMessage { table :: !Text, records :: ![HashMap Text Value], requestId :: !Int, transactionId :: !(Maybe UUID) }
17+
| UpdateRecordMessage { table :: !Text, id :: !UUID, patch :: !(HashMap Text Value), requestId :: !Int, transactionId :: !(Maybe UUID) }
18+
| UpdateRecordsMessage { table :: !Text, ids :: ![UUID], patch :: !(HashMap Text Value), requestId :: !Int, transactionId :: !(Maybe UUID) }
19+
| DeleteRecordMessage { table :: !Text, id :: !UUID, requestId :: !Int, transactionId :: !(Maybe UUID) }
20+
| DeleteRecordsMessage { table :: !Text, ids :: ![UUID], requestId :: !Int, transactionId :: !(Maybe UUID) }
21+
| StartTransaction { requestId :: !Int }
22+
| RollbackTransaction { requestId :: !Int, id :: !UUID }
23+
| CommitTransaction { requestId :: !Int, id :: !UUID }
2024
deriving (Eq, Show)
2125

2226
data DataSyncResponse
@@ -34,9 +38,21 @@ data DataSyncResponse
3438
| DidUpdateRecords { requestId :: !Int, records :: ![[Field]] } -- ^ Response to 'UpdateRecordsMessage'
3539
| DidDeleteRecord { requestId :: !Int }
3640
| DidDeleteRecords { requestId :: !Int }
41+
| DidStartTransaction { requestId :: !Int, transactionId :: !UUID }
42+
| DidRollbackTransaction { requestId :: !Int, transactionId :: !UUID }
43+
| DidCommitTransaction { requestId :: !Int, transactionId :: !UUID }
3744

3845
data Subscription = Subscription { id :: !UUID, channelSubscription :: !PGListener.Subscription }
46+
data DataSyncTransaction
47+
= DataSyncTransaction
48+
{ id :: !UUID
49+
, connection :: !PG.Connection
50+
, releaseConnection :: IO ()
51+
}
3952

4053
data DataSyncController
4154
= DataSyncController
42-
| DataSyncReady { subscriptions :: !(HashMap UUID Subscription) }
55+
| DataSyncReady
56+
{ subscriptions :: !(HashMap UUID Subscription)
57+
, transactions :: !(HashMap UUID DataSyncTransaction)
58+
}

0 commit comments

Comments
 (0)