@@ -25,6 +25,7 @@ import qualified IHP.PGListener as PGListener
25
25
import IHP.ApplicationContext
26
26
import Data.Set (Set )
27
27
import qualified Data.Set as Set
28
+ import qualified Data.Pool as Pool
28
29
29
30
instance (
30
31
PG. ToField (PrimaryKey (GetTableName CurrentUserRecord ))
@@ -36,7 +37,7 @@ instance (
36
37
initialState = DataSyncController
37
38
38
39
run = do
39
- setState DataSyncReady { subscriptions = HashMap. empty }
40
+ setState DataSyncReady { subscriptions = HashMap. empty, transactions = HashMap. empty }
40
41
41
42
ensureRLSEnabled <- makeCachedEnsureRLSEnabled
42
43
installTableChangeTriggers <- ChangeNotifications. makeCachedInstallTableChangeTriggers
@@ -45,12 +46,12 @@ instance (
45
46
46
47
let
47
48
handleMessage :: DataSyncMessage -> IO ()
48
- handleMessage DataSyncQuery { query, requestId } = do
49
+ handleMessage DataSyncQuery { query, requestId, transactionId } = do
49
50
ensureRLSEnabled (get # table query)
50
51
51
52
let (theQuery, theParams) = compileQuery query
52
53
53
- result :: [[Field ]] <- sqlQueryWithRLS theQuery theParams
54
+ result :: [[Field ]] <- sqlQueryWithRLSAndTransactionId transactionId theQuery theParams
54
55
55
56
sendJSON DataSyncResult { result, requestId }
56
57
@@ -131,7 +132,7 @@ instance (
131
132
132
133
sendJSON DidDeleteDataSubscription { subscriptionId, requestId }
133
134
134
- handleMessage CreateRecordMessage { table, record, requestId } = do
135
+ handleMessage CreateRecordMessage { table, record, requestId, transactionId } = do
135
136
ensureRLSEnabled table
136
137
137
138
let query = " INSERT INTO ? ? VALUES ? RETURNING * "
@@ -145,15 +146,15 @@ instance (
145
146
146
147
let params = (PG.Identifier table, PG.In (map PG.Identifier columns), PG.In values)
147
148
148
- result :: [[Field]] <- sqlQueryWithRLS query params
149
+ result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params
149
150
150
151
case result of
151
152
[record] -> sendJSON DidCreateRecord { requestId, record }
152
153
otherwise -> error " Unexpected result in CreateRecordMessage handler"
153
154
154
155
pure ()
155
156
156
- handleMessage CreateRecordsMessage { table, records, requestId } = do
157
+ handleMessage CreateRecordsMessage { table, records, requestId, transactionId } = do
157
158
ensureRLSEnabled table
158
159
159
160
let query = " INSERT INTO ? ? ? RETURNING * "
@@ -175,13 +176,13 @@ instance (
175
176
176
177
let params = (PG.Identifier table, PG.In (map PG.Identifier columns), PG.Values [] values)
177
178
178
- records :: [[Field]] <- sqlQueryWithRLS query params
179
+ records :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params
179
180
180
181
sendJSON DidCreateRecords { requestId, records }
181
182
182
183
pure ()
183
184
184
- handleMessage UpdateRecordMessage { table, id, patch, requestId } = do
185
+ handleMessage UpdateRecordMessage { table, id, patch, requestId, transactionId } = do
185
186
ensureRLSEnabled table
186
187
187
188
let columns = patch
@@ -204,15 +205,15 @@ instance (
204
205
<> (join (map (\(key, value) -> [PG.toField key, value]) keyValues))
205
206
<> [PG.toField id]
206
207
207
- result :: [[Field]] <- sqlQueryWithRLS (PG.Query query) params
208
+ result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params
208
209
209
210
case result of
210
211
[record] -> sendJSON DidUpdateRecord { requestId, record }
211
212
otherwise -> error " Unexpected result in UpdateRecordMessage handler"
212
213
213
214
pure ()
214
215
215
- handleMessage UpdateRecordsMessage { table, ids, patch, requestId } = do
216
+ handleMessage UpdateRecordsMessage { table, ids, patch, requestId, transactionId } = do
216
217
ensureRLSEnabled table
217
218
218
219
let columns = patch
@@ -235,26 +236,63 @@ instance (
235
236
<> (join (map (\(key, value) -> [PG.toField key, value]) keyValues))
236
237
<> [PG.toField (PG.In ids)]
237
238
238
- records <- sqlQueryWithRLS (PG.Query query) params
239
+ records <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params
239
240
240
241
sendJSON DidUpdateRecords { requestId, records }
241
242
242
243
pure ()
243
244
244
- handleMessage DeleteRecordMessage { table, id, requestId } = do
245
+ handleMessage DeleteRecordMessage { table, id, requestId, transactionId } = do
245
246
ensureRLSEnabled table
246
247
247
- sqlExecWithRLS " DELETE FROM ? WHERE id = ? " (PG.Identifier table, id)
248
+ sqlExecWithRLSAndTransactionId transactionId " DELETE FROM ? WHERE id = ? " (PG.Identifier table, id)
248
249
249
250
sendJSON DidDeleteRecord { requestId }
250
251
251
- handleMessage DeleteRecordsMessage { table, ids, requestId } = do
252
+ handleMessage DeleteRecordsMessage { table, ids, requestId, transactionId } = do
252
253
ensureRLSEnabled table
253
254
254
- sqlExecWithRLS " DELETE FROM ? WHERE id IN ? " (PG.Identifier table, PG.In ids)
255
+ sqlExecWithRLSAndTransactionId transactionId " DELETE FROM ? WHERE id IN ? " (PG.Identifier table, PG.In ids)
255
256
256
257
sendJSON DidDeleteRecords { requestId }
257
258
259
+ handleMessage StartTransaction { requestId } = do
260
+ ensureBelowTransactionLimit
261
+
262
+ transactionId <- UUID.nextRandom
263
+
264
+ (connection, localPool) <- ?modelContext
265
+ |> get #connectionPool
266
+ |> Pool.takeResource
267
+
268
+ let transaction = DataSyncTransaction
269
+ { id = transactionId
270
+ , connection
271
+ , releaseConnection = Pool.putResource localPool connection
272
+ }
273
+
274
+ let globalModelContext = ?modelContext
275
+ let ?modelContext = globalModelContext { transactionConnection = Just connection } in sqlExecWithRLS " BEGIN " ()
276
+
277
+ modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.insert transactionId transaction))
278
+
279
+ sendJSON DidStartTransaction { requestId, transactionId }
280
+
281
+ handleMessage RollbackTransaction { requestId, id } = do
282
+ sqlExecWithRLSAndTransactionId (Just id) " ROLLBACK " ()
283
+
284
+ closeTransaction id
285
+
286
+ sendJSON DidRollbackTransaction { requestId, transactionId = id }
287
+
288
+ handleMessage CommitTransaction { requestId, id } = do
289
+ sqlExecWithRLSAndTransactionId (Just id) " COMMIT " ()
290
+
291
+ closeTransaction id
292
+
293
+ sendJSON DidCommitTransaction { requestId, transactionId = id }
294
+
295
+
258
296
259
297
forever do
260
298
message <- Aeson.eitherDecodeStrict' <$> receiveData @ByteString
@@ -289,13 +327,15 @@ cleanupAllSubscriptions = do
289
327
let pgListener = ?applicationContext |> get #pgListener
290
328
291
329
case state of
292
- DataSyncReady { subscriptions } -> do
330
+ DataSyncReady { subscriptions, transactions } -> do
293
331
let channelSubscriptions = subscriptions
294
332
|> HashMap.elems
295
333
|> map (get #channelSubscription)
296
334
forEach channelSubscriptions \channelSubscription -> do
297
335
pgListener |> PGListener.unsubscribe channelSubscription
298
336
337
+ forEach (HashMap.elems transactions) (get #releaseConnection)
338
+
299
339
pure ()
300
340
_ -> pure ()
301
341
@@ -310,8 +350,81 @@ queryFieldNamesToColumnNames sqlQuery = sqlQuery
310
350
where
311
351
convertOrderByClause OrderByClause { orderByColumn, orderByDirection } = OrderByClause { orderByColumn = cs (fieldNameToColumnName (cs orderByColumn)), orderByDirection }
312
352
353
+
354
+ runInModelContextWithTransaction :: (?state :: IORef DataSyncController, _) => ((?modelContext :: ModelContext) => IO result) -> Maybe UUID -> IO result
355
+ runInModelContextWithTransaction function (Just transactionId) = do
356
+ let globalModelContext = ?modelContext
357
+
358
+ DataSyncTransaction { connection } <- findTransactionById transactionId
359
+ let
360
+ ?modelContext = globalModelContext { transactionConnection = Just connection }
361
+ in
362
+ function
363
+ runInModelContextWithTransaction function Nothing = function
364
+
365
+ findTransactionById :: (?state :: IORef DataSyncController) => UUID -> IO DataSyncTransaction
366
+ findTransactionById transactionId = do
367
+ transactions <- get #transactions <$> readIORef ?state
368
+ case HashMap.lookup transactionId transactions of
369
+ Just transaction -> pure transaction
370
+ Nothing -> error " No transaction with that id "
371
+
372
+ closeTransaction transactionId = do
373
+ DataSyncTransaction { releaseConnection } <- findTransactionById transactionId
374
+ modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.delete transactionId))
375
+ releaseConnection
376
+
377
+ -- | Allow max 10 concurrent transactions per connection to avoid running out of database connections
378
+ --
379
+ -- Each transaction removes a database connection from the connection pool. If we don't limit the transactions,
380
+ -- a single user could take down the application by starting more than 'IHP.FrameworkConfig.DBPoolMaxConnections'
381
+ -- concurrent transactions. Then all database connections are removed from the connection pool and further database
382
+ -- queries for other users will fail.
383
+ --
384
+ ensureBelowTransactionLimit :: (?state :: IORef DataSyncController) => IO ()
385
+ ensureBelowTransactionLimit = do
386
+ transactions <- get #transactions <$> readIORef ?state
387
+ let transactionCount = HashMap.size transactions
388
+ let maxTransactionsPerConnection = 10
389
+ when (transactionCount >= maxTransactionsPerConnection) do
390
+ error (" You've reached the transaction limit of " <> tshow maxTransactionsPerConnection <> " transactions" )
391
+
392
+ sqlQueryWithRLSAndTransactionId ::
393
+ ( ?modelContext :: ModelContext
394
+ , PG.ToRow parameters
395
+ , ?context :: ControllerContext
396
+ , userId ~ Id CurrentUserRecord
397
+ , Show (PrimaryKey (GetTableName CurrentUserRecord))
398
+ , HasNewSessionUrl CurrentUserRecord
399
+ , Typeable CurrentUserRecord
400
+ , ?context :: ControllerContext
401
+ , HasField " id " CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
402
+ , PG.ToField userId
403
+ , FromRow result
404
+ , ?state :: IORef DataSyncController
405
+ ) => Maybe UUID -> PG.Query -> parameters -> IO [result]
406
+ sqlQueryWithRLSAndTransactionId transactionId theQuery theParams = runInModelContextWithTransaction (sqlQueryWithRLS theQuery theParams) transactionId
407
+
408
+ sqlExecWithRLSAndTransactionId ::
409
+ ( ?modelContext :: ModelContext
410
+ , PG.ToRow parameters
411
+ , ?context :: ControllerContext
412
+ , userId ~ Id CurrentUserRecord
413
+ , Show (PrimaryKey (GetTableName CurrentUserRecord))
414
+ , HasNewSessionUrl CurrentUserRecord
415
+ , Typeable CurrentUserRecord
416
+ , ?context :: ControllerContext
417
+ , HasField " id " CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
418
+ , PG.ToField userId
419
+ , ?state :: IORef DataSyncController
420
+ ) => Maybe UUID -> PG.Query -> parameters -> IO Int64
421
+ sqlExecWithRLSAndTransactionId transactionId theQuery theParams = runInModelContextWithTransaction (sqlExecWithRLS theQuery theParams) transactionId
422
+
313
423
$(deriveFromJSON defaultOptions 'DataSyncQuery)
314
424
$(deriveToJSON defaultOptions 'DataSyncResult)
315
425
316
426
instance SetField " subscriptions" DataSyncController (HashMap UUID Subscription) where
317
- setField subscriptions record = record { subscriptions }
427
+ setField subscriptions record = record { subscriptions }
428
+
429
+ instance SetField " transactions" DataSyncController (HashMap UUID DataSyncTransaction) where
430
+ setField transactions record = record { transactions }
0 commit comments