Skip to content

Commit 501bbef

Browse files
committed
Fixed data subscriptions not always closed in specific cases
We switched from manual resource handling to Exception.bracket with MVars to make sure that all data subscriptions are always closed when the underlying websocket connection is closed
1 parent 2c5afc0 commit 501bbef

File tree

2 files changed

+79
-62
lines changed

2 files changed

+79
-62
lines changed

IHP/DataSync/Controller.hs

Lines changed: 75 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ instance (
3737
initialState = DataSyncController
3838

3939
run = do
40-
setState DataSyncReady { subscriptions = HashMap.empty, transactions = HashMap.empty }
40+
setState DataSyncReady { subscriptions = HashMap.empty, transactions = HashMap.empty, asyncs = [] }
4141

4242
ensureRLSEnabled <- makeCachedEnsureRLSEnabled
4343
installTableChangeTriggers <- ChangeNotifications.makeCachedInstallTableChangeTriggers
@@ -56,6 +56,8 @@ instance (
5656
sendJSON DataSyncResult { result, requestId }
5757

5858
handleMessage CreateDataSubscription { query, requestId } = do
59+
ensureBelowSubscriptionsLimit
60+
5961
tableNameRLS <- ensureRLSEnabled (get #table query)
6062

6163
subscriptionId <- UUID.nextRandom
@@ -111,22 +113,23 @@ instance (
111113
when isWatchingRecord do
112114
sendJSON DidDelete { subscriptionId, id }
113115
116+
let subscribe = PGListener.subscribeJSON (ChangeNotifications.channelName tableNameRLS) callback pgListener
117+
let unsubscribe subscription = PGListener.unsubscribe subscription pgListener
114118
115-
channelSubscription <- pgListener
116-
|> PGListener.subscribeJSON (ChangeNotifications.channelName tableNameRLS) callback
119+
Exception.bracket subscribe unsubscribe \channelSubscription -> do
120+
close <- MVar.newEmptyMVar
121+
modifyIORef' ?state (\state -> state |> modify #subscriptions (HashMap.insert subscriptionId close))
117122
118-
modifyIORef' ?state (\state -> state |> modify #subscriptions (HashMap.insert subscriptionId Subscription { id = subscriptionId, channelSubscription }))
123+
sendJSON DidCreateDataSubscription { subscriptionId, requestId, result }
119124
120-
sendJSON DidCreateDataSubscription { subscriptionId, requestId, result }
125+
MVar.takeMVar close
121126
122127
handleMessage DeleteDataSubscription { requestId, subscriptionId } = do
123128
DataSyncReady { subscriptions } <- getState
124-
let maybeSubscription :: Maybe Subscription = HashMap.lookup subscriptionId subscriptions
129+
let (Just closeSignalMVar) = HashMap.lookup subscriptionId subscriptions
125130
126131
-- Cancel table watcher
127-
case maybeSubscription of
128-
Just subscription -> pgListener |> PGListener.unsubscribe (get #channelSubscription subscription)
129-
Nothing -> pure ()
132+
MVar.putMVar closeSignalMVar ()
130133
131134
modifyIORef' ?state (\state -> state |> modify #subscriptions (HashMap.delete subscriptionId))
132135
@@ -260,35 +263,49 @@ instance (
260263
ensureBelowTransactionLimit
261264
262265
transactionId <- UUID.nextRandom
263-
264-
(connection, localPool) <- ?modelContext
265-
|> get #connectionPool
266-
|> Pool.takeResource
267266
268-
let transaction = DataSyncTransaction
269-
{ id = transactionId
270-
, connection
271-
, releaseConnection = Pool.putResource localPool connection
272-
}
273267
274-
let globalModelContext = ?modelContext
275-
let ?modelContext = globalModelContext { transactionConnection = Just connection } in sqlExecWithRLS "BEGIN" ()
268+
let takeConnection = ?modelContext
269+
|> get #connectionPool
270+
|> Pool.takeResource
271+
272+
let releaseConnection (connection, localPool) = do
273+
PG.execute connection "ROLLBACK" () -- Make sure there's no pending transaction in case something went wrong
274+
Pool.putResource localPool connection
275+
276+
Exception.bracket takeConnection releaseConnection \(connection, localPool) -> do
277+
transactionSignal <- MVar.newEmptyMVar
278+
279+
let globalModelContext = ?modelContext
280+
let ?modelContext = globalModelContext { transactionConnection = Just connection } in sqlExecWithRLS "BEGIN" ()
276281
277-
modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.insert transactionId transaction))
282+
let transaction = DataSyncTransaction
283+
{ id = transactionId
284+
, connection
285+
, close = transactionSignal
286+
}
278287
279-
sendJSON DidStartTransaction { requestId, transactionId }
288+
modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.insert transactionId transaction))
289+
290+
sendJSON DidStartTransaction { requestId, transactionId }
291+
292+
MVar.takeMVar transactionSignal
293+
294+
modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.delete transactionId))
280295
281296
handleMessage RollbackTransaction { requestId, id } = do
282-
sqlExecWithRLSAndTransactionId (Just id) "ROLLBACK" ()
297+
DataSyncTransaction { id, close } <- findTransactionById id
283298
284-
closeTransaction id
299+
sqlExecWithRLSAndTransactionId (Just id) "ROLLBACK" ()
300+
MVar.putMVar close ()
285301
286302
sendJSON DidRollbackTransaction { requestId, transactionId = id }
287303
288304
handleMessage CommitTransaction { requestId, id } = do
289-
sqlExecWithRLSAndTransactionId (Just id) "COMMIT" ()
305+
DataSyncTransaction { id, close } <- findTransactionById id
290306
291-
closeTransaction id
307+
sqlExecWithRLSAndTransactionId (Just id) "COMMIT" ()
308+
MVar.putMVar close ()
292309
293310
sendJSON DidCommitTransaction { requestId, transactionId = id }
294311
@@ -301,22 +318,24 @@ instance (
301318
Right decodedMessage -> do
302319
let requestId = get #requestId decodedMessage
303320
304-
-- Handle the messages in an async way
305-
-- This increases throughput as multiple queries can be fetched
306-
-- in parallel
307-
async do
308-
result <- Exception.try (handleMessage decodedMessage)
309-
310-
case result of
311-
Left (e :: Exception.SomeException) -> do
312-
let errorMessage = case fromException e of
313-
Just (enhancedSqlError :: EnhancedSqlError) -> cs (get #sqlErrorMsg (get #sqlError enhancedSqlError))
314-
Nothing -> cs (displayException e)
315-
Log.error (tshow e)
316-
sendJSON DataSyncError { requestId, errorMessage }
317-
Right result -> pure ()
318-
319-
pure ()
321+
Exception.mask \restore -> do
322+
-- Handle the messages in an async way
323+
-- This increases throughput as multiple queries can be fetched
324+
-- in parallel
325+
handlerProcess <- async $ restore do
326+
result <- Exception.try (handleMessage decodedMessage)
327+
328+
case result of
329+
Left (e :: Exception.SomeException) -> do
330+
let errorMessage = case fromException e of
331+
Just (enhancedSqlError :: EnhancedSqlError) -> cs (get #sqlErrorMsg (get #sqlError enhancedSqlError))
332+
Nothing -> cs (displayException e)
333+
Log.error (tshow e)
334+
sendJSON DataSyncError { requestId, errorMessage }
335+
Right result -> pure ()
336+
337+
modifyIORef' ?state (\state -> state |> modify #asyncs (handlerProcess:))
338+
pure ()
320339
Left errorMessage -> sendJSON FailedToDecodeMessageError { errorMessage = cs errorMessage }
321340
322341
onClose = cleanupAllSubscriptions
@@ -327,16 +346,7 @@ cleanupAllSubscriptions = do
327346
let pgListener = ?applicationContext |> get #pgListener
328347
329348
case state of
330-
DataSyncReady { subscriptions, transactions } -> do
331-
let channelSubscriptions = subscriptions
332-
|> HashMap.elems
333-
|> map (get #channelSubscription)
334-
forEach channelSubscriptions \channelSubscription -> do
335-
pgListener |> PGListener.unsubscribe channelSubscription
336-
337-
forEach (HashMap.elems transactions) (get #releaseConnection)
338-
339-
pure ()
349+
DataSyncReady { asyncs } -> forEach asyncs uninterruptibleCancel
340350
_ -> pure ()
341351
342352
changesToValue :: [ChangeNotifications.Change] -> Value
@@ -369,11 +379,6 @@ findTransactionById transactionId = do
369379
Just transaction -> pure transaction
370380
Nothing -> error "No transaction with that id"
371381
372-
closeTransaction transactionId = do
373-
DataSyncTransaction { releaseConnection } <- findTransactionById transactionId
374-
modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.delete transactionId))
375-
releaseConnection
376-
377382
-- | Allow max 10 concurrent transactions per connection to avoid running out of database connections
378383
--
379384
-- Each transaction removes a database connection from the connection pool. If we don't limit the transactions,
@@ -389,6 +394,14 @@ ensureBelowTransactionLimit = do
389394
when (transactionCount >= maxTransactionsPerConnection) do
390395
error ("You've reached the transaction limit of " <> tshow maxTransactionsPerConnection <> " transactions")
391396
397+
ensureBelowSubscriptionsLimit :: (?state :: IORef DataSyncController) => IO ()
398+
ensureBelowSubscriptionsLimit = do
399+
subscriptions <- get #subscriptions <$> readIORef ?state
400+
let subscriptionsCount = HashMap.size subscriptions
401+
let maxSubscriptionsPerConnection = 128
402+
when (subscriptionsCount >= maxSubscriptionsPerConnection) do
403+
error ("You've reached the subscriptions limit of " <> tshow maxSubscriptionsPerConnection <> " subscriptions")
404+
392405
sqlQueryWithRLSAndTransactionId ::
393406
( ?modelContext :: ModelContext
394407
, PG.ToRow parameters
@@ -423,8 +436,11 @@ sqlExecWithRLSAndTransactionId transactionId theQuery theParams = runInModelCont
423436
$(deriveFromJSON defaultOptions 'DataSyncQuery)
424437
$(deriveToJSON defaultOptions 'DataSyncResult)
425438
426-
instance SetField "subscriptions" DataSyncController (HashMap UUID Subscription) where
439+
instance SetField "subscriptions" DataSyncController (HashMap UUID (MVar.MVar ())) where
427440
setField subscriptions record = record { subscriptions }
428441
429442
instance SetField "transactions" DataSyncController (HashMap UUID DataSyncTransaction) where
430-
setField transactions record = record { transactions }
443+
setField transactions record = record { transactions }
444+
445+
instance SetField "asyncs" DataSyncController [Async ()] where
446+
setField asyncs record = record { asyncs }

IHP/DataSync/Types.hs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import IHP.DataSync.DynamicQuery
77
import Data.HashMap.Strict (HashMap)
88
import qualified IHP.PGListener as PGListener
99
import qualified Database.PostgreSQL.Simple as PG
10+
import Control.Concurrent.MVar as MVar
1011

1112
data DataSyncMessage
1213
= DataSyncQuery { query :: !DynamicSQLQuery, requestId :: !Int, transactionId :: !(Maybe UUID) }
@@ -42,17 +43,17 @@ data DataSyncResponse
4243
| DidRollbackTransaction { requestId :: !Int, transactionId :: !UUID }
4344
| DidCommitTransaction { requestId :: !Int, transactionId :: !UUID }
4445

45-
data Subscription = Subscription { id :: !UUID, channelSubscription :: !PGListener.Subscription }
4646
data DataSyncTransaction
4747
= DataSyncTransaction
4848
{ id :: !UUID
4949
, connection :: !PG.Connection
50-
, releaseConnection :: IO ()
50+
, close :: MVar ()
5151
}
5252

5353
data DataSyncController
5454
= DataSyncController
5555
| DataSyncReady
56-
{ subscriptions :: !(HashMap UUID Subscription)
56+
{ subscriptions :: !(HashMap UUID (MVar.MVar ()))
5757
, transactions :: !(HashMap UUID DataSyncTransaction)
58+
, asyncs :: ![Async ()]
5859
}

0 commit comments

Comments
 (0)