Skip to content

Commit e904b91

Browse files
committed
Added support for 'CREATE POLICY .. FOR ..' sql statements
1 parent 5961386 commit e904b91

File tree

7 files changed

+53
-9
lines changed

7 files changed

+53
-9
lines changed

IHP/IDE/CodeGen/MigrationGenerator.hs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,14 @@ normalizeStatement StatementCreateTable { unsafeGetCreateTable = table } = State
340340
(normalizedTable, normalizeTableRest) = normalizeTable table
341341
normalizeStatement AddConstraint { tableName, constraint } = [ AddConstraint { tableName, constraint = normalizeConstraint constraint } ]
342342
normalizeStatement CreateEnumType { name, values } = [ CreateEnumType { name = Text.toLower name, values = map Text.toLower values } ]
343-
normalizeStatement CreatePolicy { name, tableName, using, check } = [ CreatePolicy { name, tableName, using = normalizeExpression <$> using, check = normalizeExpression <$> check } ]
343+
normalizeStatement CreatePolicy { name, action, tableName, using, check } = [ CreatePolicy { name, tableName, using = normalizeExpression <$> using, check = normalizeExpression <$> check, action = normalizePolicyAction action } ]
344344
normalizeStatement CreateIndex { expressions, .. } = [ CreateIndex { expressions = map normalizeExpression expressions, .. } ]
345345
normalizeStatement CreateFunction { .. } = [ CreateFunction { orReplace = False, .. } ]
346346
normalizeStatement otherwise = [otherwise]
347347

348+
normalizePolicyAction (Just PolicyForAll) = Nothing
349+
normalizePolicyAction otherwise = otherwise
350+
348351
normalizeTable :: CreateTable -> (CreateTable, [Statement])
349352
normalizeTable table@(CreateTable { .. }) = ( CreateTable { columns = fst normalizedColumns, constraints = normalizedTableConstraints, .. }, (concat $ (snd normalizedColumns)) <> normalizedConstraintsStatements )
350353
where

IHP/IDE/SchemaDesigner/Compiler.hs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ compileStatement Comment { content } = "--" <> content
3535
compileStatement CreateIndex { indexName, unique, tableName, expressions, whereClause } = "CREATE" <> (if unique then " UNIQUE " else " ") <> "INDEX " <> indexName <> " ON " <> tableName <> " (" <> (intercalate ", " (map compileExpression expressions)) <> ")" <> (case whereClause of Just expression -> " WHERE " <> compileExpression expression; Nothing -> "") <> ";"
3636
compileStatement CreateFunction { functionName, functionBody, orReplace, returns, language } = "CREATE " <> (if orReplace then "OR REPLACE " else "") <> "FUNCTION " <> functionName <> "() RETURNS " <> compilePostgresType returns <> " AS $$" <> functionBody <> "$$ language " <> language <> ";"
3737
compileStatement EnableRowLevelSecurity { tableName } = "ALTER TABLE " <> tableName <> " ENABLE ROW LEVEL SECURITY;"
38-
compileStatement CreatePolicy { name, tableName, using, check } = "CREATE POLICY " <> compileIdentifier name <> " ON " <> compileIdentifier tableName <> maybe "" (\expr -> " USING (" <> compileExpression expr <> ")") using <> maybe "" (\expr -> " WITH CHECK (" <> compileExpression expr <> ")") check <> ";"
38+
compileStatement CreatePolicy { name, action, tableName, using, check } = "CREATE POLICY " <> compileIdentifier name <> maybe "" (\action -> " FOR " <> compilePolicyAction action) action <> " ON " <> compileIdentifier tableName <> maybe "" (\expr -> " USING (" <> compileExpression expr <> ")") using <> maybe "" (\expr -> " WITH CHECK (" <> compileExpression expr <> ")") check <> ";"
3939
compileStatement CreateSequence { name } = "CREATE SEQUENCE " <> compileIdentifier name <> ";"
4040
compileStatement DropConstraint { tableName, constraintName } = "ALTER TABLE " <> compileIdentifier tableName <> " DROP CONSTRAINT " <> compileIdentifier constraintName <> ";"
4141
compileStatement DropEnumType { name } = "DROP TYPE " <> compileIdentifier name <> ";"
@@ -432,4 +432,11 @@ compileTriggerEvent TriggerOnTruncate = "TRUNCATE"
432432

433433
compileTriggerFor :: TriggerFor -> Text
434434
compileTriggerFor ForEachRow = "FOR EACH ROW"
435-
compileTriggerFor ForEachStatement = "FOR EACH STATEMENT"
435+
compileTriggerFor ForEachStatement = "FOR EACH STATEMENT"
436+
437+
compilePolicyAction :: PolicyAction -> Text
438+
compilePolicyAction PolicyForAll = "ALL"
439+
compilePolicyAction PolicyForSelect = "SELECT"
440+
compilePolicyAction PolicyForInsert = "INSERT"
441+
compilePolicyAction PolicyForUpdate = "UPDATE"
442+
compilePolicyAction PolicyForDelete = "DELETE"

IHP/IDE/SchemaDesigner/Parser.hs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ createPolicy = do
652652
lexeme "CREATE"
653653
lexeme "POLICY"
654654
name <- identifier
655+
action <- optional (lexeme "FOR" >> policyAction)
655656
lexeme "ON"
656657
tableName <- qualifiedIdentifier
657658

@@ -666,8 +667,14 @@ createPolicy = do
666667

667668
char ';'
668669

669-
pure CreatePolicy { name, tableName, using, check }
670+
pure CreatePolicy { name, action, tableName, using, check }
670671

672+
policyAction =
673+
(lexeme "ALL" >> pure PolicyForAll)
674+
<|> (lexeme "SELECT" >> pure PolicyForSelect)
675+
<|> (lexeme "INSERT" >> pure PolicyForInsert)
676+
<|> (lexeme "UPDATE" >> pure PolicyForUpdate)
677+
<|> (lexeme "DELETE" >> pure PolicyForDelete)
671678

672679
setStatement = do
673680
lexeme "SET"

IHP/IDE/SchemaDesigner/SchemaOperations.hs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ updatePolicy UpdatePolicyOptions { .. } statements =
151151
statements
152152
|> map updatePolicy'
153153
where
154-
updatePolicy' policy@CreatePolicy { name = pName, tableName = pTable } | pName == currentName && pTable == tableName = CreatePolicy { tableName, name, using, check }
154+
updatePolicy' policy@CreatePolicy { name = pName, action, tableName = pTable } | pName == currentName && pTable == tableName = CreatePolicy { tableName, action, name, using, check }
155155
updatePolicy' otherwise = otherwise
156156

157157
data AddPolicyOptions = AddPolicyOptions
@@ -164,7 +164,7 @@ data AddPolicyOptions = AddPolicyOptions
164164
addPolicy :: AddPolicyOptions -> Schema -> Schema
165165
addPolicy AddPolicyOptions { .. } statements = statements <> createPolicyStatement
166166
where
167-
createPolicyStatement = [ CreatePolicy { tableName, name, using, check } ]
167+
createPolicyStatement = [ CreatePolicy { tableName, action = Nothing, name, using, check } ]
168168

169169
data DeletePolicyOptions = DeletePolicyOptions
170170
{ tableName :: !Text
@@ -235,6 +235,7 @@ suggestPolicy :: Schema -> Statement -> Statement
235235
suggestPolicy schema (StatementCreateTable CreateTable { name = tableName, columns })
236236
| isJust (find isUserIdColumn columns) = CreatePolicy
237237
{ name = "Users can manage their " <> tableName
238+
, action = Nothing
238239
, tableName
239240
, using = Just compareUserId
240241
, check = Just compareUserId
@@ -252,6 +253,7 @@ suggestPolicy schema (StatementCreateTable CreateTable { name = tableName, colum
252253
columnWithFKAndRefTableToPolicy :: (Column, Constraint, CreateTable) -> Maybe Statement
253254
columnWithFKAndRefTableToPolicy (column, ForeignKeyConstraint { referenceColumn }, CreateTable { name = refTableName, columns = refTableColumns }) | isJust (find isUserIdColumn refTableColumns) = Just CreatePolicy
254255
{ name = "Users can manage the " <> tableName <> " if they can see the " <> tableNameToModelName refTableName
256+
, action = Nothing
255257
, tableName
256258
, using = Just delegateCheck
257259
, check = Just delegateCheck
@@ -304,7 +306,7 @@ suggestPolicy schema (StatementCreateTable CreateTable { name = tableName, colum
304306
|> fmap \case
305307
StatementCreateTable table -> table
306308

307-
emptyPolicy = CreatePolicy { name = "", tableName, using = Nothing, check = Nothing }
309+
emptyPolicy = CreatePolicy { name = "", action = Nothing, tableName, using = Nothing, check = Nothing }
308310

309311
isUserIdColumn :: Column -> Bool
310312
isUserIdColumn Column { name = "user_id" } = True

IHP/IDE/SchemaDesigner/Types.hs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ data Statement
3939
-- | ALTER TABLE tableName ENABLE ROW LEVEL SECURITY;
4040
| EnableRowLevelSecurity { tableName :: Text }
4141
-- CREATE POLICY name ON tableName USING using WITH CHECK check;
42-
| CreatePolicy { name :: Text, tableName :: Text, using :: Maybe Expression, check :: Maybe Expression }
42+
| CreatePolicy { name :: Text, tableName :: Text, action :: Maybe PolicyAction, using :: Maybe Expression, check :: Maybe Expression }
4343
-- SET name = value;
4444
| Set { name :: Text, value :: Expression }
4545
-- SELECT query;
@@ -214,4 +214,12 @@ data TriggerEvent
214214
data TriggerFor
215215
= ForEachRow
216216
| ForEachStatement
217+
deriving (Eq, Show)
218+
219+
data PolicyAction
220+
= PolicyForAll
221+
| PolicyForSelect
222+
| PolicyForInsert
223+
| PolicyForUpdate
224+
| PolicyForDelete
217225
deriving (Eq, Show)

Test/IDE/SchemaDesigner/ParserSpec.hs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,7 @@ $$;
528528
it "should parse 'CREATE POLICY' statements" do
529529
parseSql "CREATE POLICY \"Users can manage their tasks\" ON tasks USING (user_id = ihp_user_id()) WITH CHECK (user_id = ihp_user_id());" `shouldBe` CreatePolicy
530530
{ name = "Users can manage their tasks"
531+
, action = Nothing
531532
, tableName = "tasks"
532533
, using = Just (
533534
EqExpression
@@ -718,10 +719,21 @@ COMMENT ON EXTENSION "uuid-ossp" IS 'generate universally unique identifiers (UU
718719
it "should parse 'DROP POLICY .. ON ..' statements" do
719720
parseSql "DROP POLICY \"Users can manage their todos\" ON todos;" `shouldBe` DropPolicy { tableName = "todos", policyName = "Users can manage their todos" }
720721

722+
it "should parse 'CREATE POLICY .. FOR SELECT' statements" do
723+
let sql = cs [plain|CREATE POLICY "Messages are public" FOR SELECT ON messages USING (true);|]
724+
parseSql sql `shouldBe` CreatePolicy
725+
{ name = "Messages are public"
726+
, action = Just PolicyForSelect
727+
, tableName = "messages"
728+
, using = Just (VarExpression "true")
729+
, check = Nothing
730+
}
731+
721732
it "should parse policies with an EXISTS condition" do
722733
let sql = cs [plain|CREATE POLICY "Users can manage their project's migrations" ON migrations USING (EXISTS (SELECT 1 FROM projects WHERE id = project_id)) WITH CHECK (EXISTS (SELECT 1 FROM projects WHERE id = project_id));|]
723734
parseSql sql `shouldBe` CreatePolicy
724735
{ name = "Users can manage their project's migrations"
736+
, action = Nothing
725737
, tableName = "migrations"
726738
, using = Just (ExistsExpression (SelectExpression (Select {columns = [IntExpression 1], from = VarExpression "projects", alias = Nothing, whereClause = EqExpression (VarExpression "id") (VarExpression "project_id")})))
727739
, check = Just (ExistsExpression (SelectExpression (Select {columns = [IntExpression 1], from = VarExpression "projects", alias = Nothing, whereClause = EqExpression (VarExpression "id") (VarExpression "project_id")})))
@@ -731,6 +743,7 @@ COMMENT ON EXTENSION "uuid-ossp" IS 'generate universally unique identifiers (UU
731743
let sql = cs [plain|CREATE POLICY "Users can manage their project's migrations" ON migrations USING (EXISTS (SELECT 1 FROM public.projects WHERE projects.id = migrations.project_id)) WITH CHECK (EXISTS (SELECT 1 FROM public.projects WHERE projects.id = migrations.project_id));|]
732744
parseSql sql `shouldBe` CreatePolicy
733745
{ name = "Users can manage their project's migrations"
746+
, action = Nothing
734747
, tableName = "migrations"
735748
, using = Just (ExistsExpression (SelectExpression (Select {columns = [IntExpression 1], from = DotExpression (VarExpression "public") "projects", alias = Nothing, whereClause = EqExpression (DotExpression (VarExpression "projects") "id") (DotExpression (VarExpression "migrations") "project_id")})))
736749
, check = Just (ExistsExpression (SelectExpression (Select {columns = [IntExpression 1], from = DotExpression (VarExpression "public") "projects", alias = Nothing, whereClause = EqExpression (DotExpression (VarExpression "projects") "id") (DotExpression (VarExpression "migrations") "project_id")})))
@@ -785,6 +798,7 @@ COMMENT ON EXTENSION "uuid-ossp" IS 'generate universally unique identifiers (UU
785798
|]
786799
parseSql sql `shouldBe` CreatePolicy
787800
{ name = "Users can see other users in their company"
801+
, action = Nothing
788802
, tableName = "users"
789803
, using = Just (EqExpression (VarExpression "company_id") (SelectExpression (Select {columns = [DotExpression (VarExpression "users_1") "company_id"], from = DotExpression (VarExpression "public") "users", alias = Just "users_1", whereClause = EqExpression (DotExpression (VarExpression "users_1") "id") (CallExpression "ihp_user_id" [])})))
790804
, check = Nothing

Test/IDE/SchemaDesigner/SchemaOperationsSpec.hs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ tests = do
7979
(SchemaOperations.disableRowLevelSecurityIfNoPolicies "a" inputSchema) `shouldBe` inputSchema
8080

8181
it "should not do anything if there's a policy" do
82-
let policy = CreatePolicy { tableName = "a", name = "p", check = Nothing, using = Nothing }
82+
let policy = CreatePolicy { tableName = "a", action = Nothing, name = "p", check = Nothing, using = Nothing }
8383
let inputSchema = [tableA, EnableRowLevelSecurity { tableName = "a"}, policy]
8484

8585
(SchemaOperations.disableRowLevelSecurityIfNoPolicies "a" inputSchema) `shouldBe` inputSchema
@@ -115,6 +115,7 @@ tests = do
115115
let schema = [table]
116116
let expectedPolicy = CreatePolicy
117117
{ name = "Users can manage their posts"
118+
, action = Nothing
118119
, tableName = "posts"
119120
, using = Just (EqExpression (VarExpression "user_id") (CallExpression "ihp_user_id" []))
120121
, check = Just (EqExpression (VarExpression "user_id") (CallExpression "ihp_user_id" []))
@@ -135,6 +136,7 @@ tests = do
135136
let schema = [table]
136137
let expectedPolicy = CreatePolicy
137138
{ name = ""
139+
, action = Nothing
138140
, tableName = "posts"
139141
, using = Nothing
140142
, check = Nothing
@@ -166,6 +168,7 @@ tests = do
166168
]
167169
let expectedPolicy = CreatePolicy
168170
{ name = "Users can manage the tasks if they can see the TaskList"
171+
, action = Nothing
169172
, tableName = "tasks"
170173
, using = Just (ExistsExpression (SelectExpression (Select {columns = [IntExpression 1], from = DotExpression (VarExpression "public") "task_lists", alias = Nothing, whereClause = EqExpression (DotExpression (VarExpression "task_lists") "id") (DotExpression (VarExpression "tasks") "task_list_id")})))
171174
, check = Just (ExistsExpression (SelectExpression (Select {columns = [IntExpression 1], from = DotExpression (VarExpression "public") "task_lists", alias = Nothing, whereClause = EqExpression (DotExpression (VarExpression "task_lists") "id") (DotExpression (VarExpression "tasks") "task_list_id")})))

0 commit comments

Comments
 (0)