@@ -49,6 +49,7 @@ import Data.Scientific
49
49
import GHC.Stack
50
50
import qualified Numeric
51
51
import qualified Data.Text.Encoding as Text
52
+ import qualified Data.ByteString.Builder as Builder
52
53
53
54
-- | Provides the db connection and some IHP-specific db configuration
54
55
data ModelContext = ModelContext
@@ -367,7 +368,7 @@ sqlQuery :: (?modelContext :: ModelContext, PG.ToRow q, PG.FromRow r, Show q) =>
367
368
sqlQuery theQuery theParameters = do
368
369
measureTimeIfLogging
369
370
(withDatabaseConnection \ connection -> enhanceSqlError theQuery theParameters do
370
- PG. query connection theQuery theParameters
371
+ withRLSParams ( PG. query connection) theQuery theParameters
371
372
)
372
373
theQuery
373
374
theParameters
@@ -382,28 +383,40 @@ sqlExec :: (?modelContext :: ModelContext, PG.ToRow q, Show q) => Query -> q ->
382
383
sqlExec theQuery theParameters = do
383
384
measureTimeIfLogging
384
385
(withDatabaseConnection \ connection -> enhanceSqlError theQuery theParameters do
385
- PG. execute connection theQuery theParameters
386
+ withRLSParams ( PG. execute connection) theQuery theParameters
386
387
)
387
388
theQuery
388
389
theParameters
389
390
{-# INLINABLE sqlExec #-}
390
391
392
+ -- | Wraps the query with Row level security boilerplate, if a row level security context was provided
393
+ --
394
+ -- __Example:__
395
+ --
396
+ -- If a row level security context is given, this will turn a query like the following
397
+ --
398
+ -- > withRLSParams runQuery "SELECT * FROM projects WHERE id = ?" (Only "..")
399
+ --
400
+ -- Into the following equivalent:
401
+ --
402
+ -- > runQuery "SET LOCAL ROLE ?; SET LOCAL rls.ihp_user_id = ?; SELECT * FROM projects WHERE id = ?" ["ihp_authenticated", "<user id>", .."]
403
+ --
404
+ withRLSParams :: (? modelContext :: ModelContext , PG. ToRow params ) => (PG. Query -> [PG. Action ] -> result ) -> PG. Query -> params -> result
405
+ withRLSParams runQuery query params = do
406
+ case get # rowLevelSecurity ? modelContext of
407
+ Just RowLevelSecurityContext { rlsAuthenticatedRole, rlsUserId } -> do
408
+ let query' = " SET LOCAL ROLE ?; SET LOCAL rls.ihp_user_id = ?; " <> query
409
+ let params' = [PG. toField (PG. Identifier rlsAuthenticatedRole), PG. toField rlsUserId] <> PG. toRow params
410
+ runQuery query' params'
411
+ Nothing -> runQuery query (PG. toRow params)
412
+
391
413
withDatabaseConnection :: (? modelContext :: ModelContext ) => (Connection -> IO a ) -> IO a
392
414
withDatabaseConnection block =
393
415
let
394
416
ModelContext { connectionPool, transactionConnection, rowLevelSecurity } = ? modelContext
395
417
in case transactionConnection of
396
418
Just transactionConnection -> block transactionConnection
397
- Nothing ->
398
- -- When row level security is enabled, we need to implicitly wrap the current query in a
399
- -- transaction, as we need to make sure the @SET LOCAL ROLE ihp_authenticated@ and
400
- -- @SET LOCAL rls.ihp_user_id = ..@ queries are executed on the same connection before
401
- -- the actual query has been executed.
402
- case rowLevelSecurity of
403
- Just rowLevelSecurity -> withTransaction do
404
- let (Just connection) = ? modelContext |> get # transactionConnection
405
- block connection
406
- Nothing -> Pool. withResource connectionPool block
419
+ Nothing -> Pool. withResource connectionPool block
407
420
{-# INLINABLE withDatabaseConnection #-}
408
421
409
422
-- | Runs a raw sql query which results in a single scalar value such as an integer or string
@@ -469,17 +482,7 @@ withTransaction block = withTransactionConnection do
469
482
|> \ case
470
483
Just connection -> connection
471
484
Nothing -> error " withTransaction: transactionConnection not set as expected"
472
- case get # rowLevelSecurity ? modelContext of
473
- -- When starting a new transaction while RLS is enabled, we switch the transaction over
474
- -- to the @ihp_authenticated@ role and also set the @rls.ihp_user_id@ variable.
475
- --
476
- -- This branch is also called from @withDatabaseConnection@, as we
477
- -- automatically wrap all queries in an implicit transaction when RLS is enabled.
478
- Just RowLevelSecurityContext { rlsAuthenticatedRole, rlsUserId } -> PG. withTransaction connection do
479
- sqlExec " SET LOCAL ROLE ?" [PG. Identifier rlsAuthenticatedRole]
480
- sqlExec " SET LOCAL rls.ihp_user_id = ?" (PG. Only rlsUserId)
481
- block
482
- Nothing -> PG. withTransaction connection block
485
+ PG. withTransaction connection block
483
486
{-# INLINABLE withTransaction #-}
484
487
485
488
-- | Executes the given block with the main database role and temporarly sidesteps the row level security policies.
@@ -500,9 +503,7 @@ withTransaction block = withTransactionConnection do
500
503
withRowLevelSecurityDisabled :: (? modelContext :: ModelContext ) => ((? modelContext :: ModelContext ) => IO a ) -> IO a
501
504
withRowLevelSecurityDisabled block = do
502
505
let currentModelContext = ? modelContext
503
- case get # rowLevelSecurity currentModelContext of
504
- Just _ -> let ? modelContext = currentModelContext { rowLevelSecurity = Nothing , transactionConnection = Nothing } in block
505
- Nothing -> block
506
+ let ? modelContext = currentModelContext { rowLevelSecurity = Nothing } in block
506
507
{-# INLINABLE withRowLevelSecurityDisabled #-}
507
508
508
509
-- | Returns the postgres connection when called within a 'withTransaction' block
@@ -589,7 +590,12 @@ logQuery query parameters time = do
589
590
-- NominalTimeDiff is represented as seconds, and doesn't provide a FormatTime option for printing in ms.
590
591
-- To get around that we convert to and from a rational so we can format as desired.
591
592
let queryTimeInMs = (time * 1000 ) |> toRational |> fromRational @ Double
592
- Log. debug (" Query (" <> tshow queryTimeInMs <> " ms): " <> tshow query <> " " <> tshow parameters)
593
+ let formatRLSInfo userId = " { ihp_user_id = " <> userId <> " }"
594
+ let rlsInfo = case get # rowLevelSecurity ? context of
595
+ Just RowLevelSecurityContext { rlsUserId = PG. Plain rlsUserId } -> formatRLSInfo (cs (Builder. toLazyByteString rlsUserId))
596
+ Just RowLevelSecurityContext { rlsUserId = rlsUserId } -> formatRLSInfo (tshow rlsUserId)
597
+ Nothing -> " "
598
+ Log. debug (" Query (" <> tshow queryTimeInMs <> " ms): " <> tshow query <> " " <> tshow parameters <> rlsInfo)
593
599
{-# INLINABLE logQuery #-}
594
600
595
601
-- | Runs a @DELETE@ query for a record.
0 commit comments