Skip to content

Commit d231ee0

Browse files
kvchandrew-farries
andauthored
Add support for creating CHECK constraints with create_constraint (#464)
This PR introduces a new constraint `type` to `create_constraint` operation called `check`. Now it is possible to create check constraints on multiple columns. ### Example ```json { "name": "45_add_table_check_constraint", "operations": [ { "create_constraint": { "type": "check", "table": "tickets", "name": "check_zip_name", "columns": [ "sellers_name", "sellers_zip" ], "check": "sellers_name ~ 'Alice' AND sellers_zip IS NOT NULL", "up": { "sellers_name": "Alice", "sellers_zip": "(SELECT CASE WHEN sellers_zip IS NOT NULL THEN sellers_zip ELSE '00000' END)" }, "down": { "sellers_name": "sellers_name", "sellers_zip": "sellers_zip" } } } ] } ``` --------- Co-authored-by: Andrew Farries <[email protected]>
1 parent 5d63a7f commit d231ee0

File tree

7 files changed

+287
-5
lines changed

7 files changed

+287
-5
lines changed

docs/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1101,7 +1101,7 @@ Example **create table** migrations:
11011101

11021102
A create constraint operation adds a new constraint to an existing table.
11031103

1104-
Only `UNIQUE` constraints are supported.
1104+
Only `UNIQUE` and `CHECK` constraints are supported.
11051105

11061106
Required fields: `name`, `table`, `type`, `up`, `down`.
11071107

@@ -1129,6 +1129,7 @@ Required fields: `name`, `table`, `type`, `up`, `down`.
11291129
Example **create constraint** migrations:
11301130

11311131
* [44_add_table_unique_constraint.json](../examples/44_add_table_unique_constraint.json)
1132+
* [45_add_table_check_constraint.json](../examples/45_add_table_check_constraint.json)
11321133

11331134

11341135
### Drop column

examples/.ledger

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@
4242
42_create_unique_index.json
4343
43_create_tickets_table.json
4444
44_add_table_unique_constraint.json
45+
45_add_table_check_constraint.json
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"name": "45_add_table_check_constraint",
3+
"operations": [
4+
{
5+
"create_constraint": {
6+
"type": "check",
7+
"table": "tickets",
8+
"name": "check_zip_name",
9+
"columns": [
10+
"sellers_name",
11+
"sellers_zip"
12+
],
13+
"check": "sellers_name = 'alice' OR sellers_zip > 0",
14+
"up": {
15+
"sellers_name": "sellers_name",
16+
"sellers_zip": "(SELECT CASE WHEN sellers_name != 'alice' AND sellers_zip <= 0 THEN 123 WHEN sellers_name != 'alice' THEN sellers_zip ELSE sellers_zip END)"
17+
},
18+
"down": {
19+
"sellers_name": "sellers_name",
20+
"sellers_zip": "sellers_zip"
21+
}
22+
}
23+
}
24+
]
25+
}

pkg/migrations/op_create_constraint.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,18 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema
6565
}
6666
}
6767

68-
switch o.Type { //nolint:gocritic // more cases will be added
68+
switch o.Type {
6969
case OpCreateConstraintTypeUnique:
7070
return table, o.addUniqueIndex(ctx, conn)
71+
case OpCreateConstraintTypeCheck:
72+
return table, o.addCheckConstraint(ctx, conn)
7173
}
7274

7375
return table, nil
7476
}
7577

7678
func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
77-
switch o.Type { //nolint:gocritic // more cases will be added
79+
switch o.Type {
7880
case OpCreateConstraintTypeUnique:
7981
uniqueOp := &OpSetUnique{
8082
Table: o.Table,
@@ -84,6 +86,17 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra
8486
if err != nil {
8587
return err
8688
}
89+
case OpCreateConstraintTypeCheck:
90+
checkOp := &OpSetCheckConstraint{
91+
Table: o.Table,
92+
Check: CheckConstraint{
93+
Name: o.Name,
94+
},
95+
}
96+
err := checkOp.Complete(ctx, conn, tr, s)
97+
if err != nil {
98+
return err
99+
}
87100
}
88101

89102
// remove old columns
@@ -176,11 +189,15 @@ func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) err
176189
}
177190
}
178191

179-
switch o.Type { //nolint:gocritic // more cases will be added
192+
switch o.Type {
180193
case OpCreateConstraintTypeUnique:
181194
if len(o.Columns) == 0 {
182195
return FieldRequiredError{Name: "columns"}
183196
}
197+
case OpCreateConstraintTypeCheck:
198+
if o.Check == nil || *o.Check == "" {
199+
return FieldRequiredError{Name: "check"}
200+
}
184201
}
185202

186203
return nil
@@ -196,6 +213,16 @@ func (o *OpCreateConstraint) addUniqueIndex(ctx context.Context, conn db.DB) err
196213
return err
197214
}
198215

216+
func (o *OpCreateConstraint) addCheckConstraint(ctx context.Context, conn db.DB) error {
217+
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
218+
pq.QuoteIdentifier(o.Table),
219+
pq.QuoteIdentifier(o.Name),
220+
rewriteCheckExpression(*o.Check, o.Columns...),
221+
))
222+
223+
return err
224+
}
225+
199226
func quotedTemporaryNames(columns []string) []string {
200227
names := make([]string, len(columns))
201228
for i, col := range columns {

pkg/migrations/op_create_constraint_test.go

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"strings"
88
"testing"
99

10+
"github.com/stretchr/testify/assert"
11+
1012
"github.com/xataio/pgroll/internal/testutils"
1113
"github.com/xataio/pgroll/pkg/migrations"
1214
)
@@ -97,6 +99,80 @@ func TestCreateConstraint(t *testing.T) {
9799
}, testutils.UniqueViolationErrorCode)
98100
},
99101
},
102+
{
103+
name: "create check constraint on single column",
104+
migrations: []migrations.Migration{
105+
{
106+
Name: "01_add_table",
107+
Operations: migrations.Operations{
108+
&migrations.OpCreateTable{
109+
Name: "users",
110+
Columns: []migrations.Column{
111+
{
112+
Name: "id",
113+
Type: "serial",
114+
Pk: ptr(true),
115+
},
116+
{
117+
Name: "name",
118+
Type: "varchar(255)",
119+
Nullable: ptr(false),
120+
},
121+
},
122+
},
123+
},
124+
},
125+
{
126+
Name: "02_create_constraint",
127+
Operations: migrations.Operations{
128+
&migrations.OpCreateConstraint{
129+
Name: "name_letters",
130+
Table: "users",
131+
Type: "check",
132+
Check: ptr("name ~ '^[a-zA-Z]+$'"),
133+
Columns: []string{"name"},
134+
Up: migrations.OpCreateConstraintUp(map[string]string{
135+
"name": "regexp_replace(name, '\\d+', '', 'g')",
136+
}),
137+
Down: migrations.OpCreateConstraintDown(map[string]string{
138+
"name": "name",
139+
}),
140+
},
141+
},
142+
},
143+
},
144+
afterStart: func(t *testing.T, db *sql.DB, schema string) {
145+
// The new (temporary) column should exist on the underlying table.
146+
ColumnMustExist(t, db, schema, "users", migrations.TemporaryName("name"))
147+
// The check constraint exists on the new table.
148+
CheckConstraintMustExist(t, db, schema, "users", "name_letters")
149+
// Inserting values into the old schema that violate the check constraint must succeed.
150+
MustInsert(t, db, schema, "01_add_table", "users", map[string]string{
151+
"name": "alice11",
152+
})
153+
154+
// Inserting values into the new schema that violate the check constraint should fail.
155+
MustInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
156+
"name": "bob",
157+
})
158+
MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
159+
"name": "bob2",
160+
}, testutils.CheckViolationErrorCode)
161+
},
162+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
163+
// Functions, triggers and temporary columns are dropped.
164+
tableCleanedUp(t, db, schema, "users", "name")
165+
},
166+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
167+
// Functions, triggers and temporary columns are dropped.
168+
tableCleanedUp(t, db, schema, "users", "name")
169+
170+
// Inserting values into the new schema that violate the check constraint should fail.
171+
MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
172+
"name": "carol0",
173+
}, testutils.CheckViolationErrorCode)
174+
},
175+
},
100176
{
101177
name: "create unique constraint on multiple columns",
102178
migrations: []migrations.Migration{
@@ -181,6 +257,104 @@ func TestCreateConstraint(t *testing.T) {
181257
// Complete is a no-op.
182258
},
183259
},
260+
{
261+
name: "create check constraint on multiple columns",
262+
migrations: []migrations.Migration{
263+
{
264+
Name: "01_add_table",
265+
Operations: migrations.Operations{
266+
&migrations.OpCreateTable{
267+
Name: "users",
268+
Columns: []migrations.Column{
269+
{
270+
Name: "id",
271+
Type: "serial",
272+
Pk: ptr(true),
273+
},
274+
{
275+
Name: "name",
276+
Type: "varchar(255)",
277+
Nullable: ptr(false),
278+
},
279+
{
280+
Name: "email",
281+
Type: "varchar(255)",
282+
Nullable: ptr(false),
283+
},
284+
},
285+
},
286+
},
287+
},
288+
{
289+
Name: "02_create_constraint",
290+
Operations: migrations.Operations{
291+
&migrations.OpCreateConstraint{
292+
Name: "check_name_email",
293+
Table: "users",
294+
Type: "check",
295+
Check: ptr("name != email"),
296+
Columns: []string{"name", "email"},
297+
Up: migrations.OpCreateConstraintUp(map[string]string{
298+
"name": "name",
299+
"email": "(SELECT CASE WHEN email ~ '@' THEN email ELSE email || '@example.com' END)",
300+
}),
301+
Down: migrations.OpCreateConstraintDown(map[string]string{
302+
"name": "name",
303+
"email": "email",
304+
}),
305+
},
306+
},
307+
},
308+
},
309+
afterStart: func(t *testing.T, db *sql.DB, schema string) {
310+
// The new (temporary) column should exist on the underlying table.
311+
ColumnMustExist(t, db, schema, "users", migrations.TemporaryName("name"))
312+
// The new (temporary) column should exist on the underlying table.
313+
ColumnMustExist(t, db, schema, "users", migrations.TemporaryName("email"))
314+
// The check constraint exists on the new table.
315+
CheckConstraintMustExist(t, db, schema, "users", "check_name_email")
316+
317+
// Inserting values into the old schema that the violate the check constraint must succeed.
318+
MustInsert(t, db, schema, "01_add_table", "users", map[string]string{
319+
"name": "alice",
320+
"email": "alice",
321+
})
322+
323+
// Inserting values into the new schema that meet the check constraint should succeed.
324+
MustInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
325+
"name": "bob",
326+
"email": "[email protected]",
327+
})
328+
MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
329+
"name": "bob",
330+
"email": "bob",
331+
}, testutils.CheckViolationErrorCode)
332+
},
333+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
334+
// The check constraint must not exists on the table.
335+
CheckConstraintMustNotExist(t, db, schema, "users", "check_name_email")
336+
// Functions, triggers and temporary columns are dropped.
337+
tableCleanedUp(t, db, schema, "users", "name")
338+
tableCleanedUp(t, db, schema, "users", "email")
339+
},
340+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
341+
// Functions, triggers and temporary columns are dropped.
342+
tableCleanedUp(t, db, schema, "users", "name")
343+
tableCleanedUp(t, db, schema, "users", "email")
344+
345+
// Inserting values into the new schema that the violate the check constraint must fail.
346+
MustNotInsert(t, db, schema, "02_create_constraint", "users", map[string]string{
347+
"name": "carol",
348+
"email": "carol",
349+
}, testutils.CheckViolationErrorCode)
350+
351+
rows := MustSelect(t, db, schema, "02_create_constraint", "users")
352+
assert.Equal(t, []map[string]any{
353+
{"id": 1, "name": "alice", "email": "[email protected]"},
354+
{"id": 2, "name": "bob", "email": "[email protected]"},
355+
}, rows)
356+
},
357+
},
184358
{
185359
name: "invalid constraint name",
186360
migrations: []migrations.Migration{
@@ -270,6 +444,52 @@ func TestCreateConstraint(t *testing.T) {
270444
afterRollback: func(t *testing.T, db *sql.DB, schema string) {},
271445
afterComplete: func(t *testing.T, db *sql.DB, schema string) {},
272446
},
447+
{
448+
name: "expression of check constraint is missing",
449+
migrations: []migrations.Migration{
450+
{
451+
Name: "01_add_table",
452+
Operations: migrations.Operations{
453+
&migrations.OpCreateTable{
454+
Name: "users",
455+
Columns: []migrations.Column{
456+
{
457+
Name: "id",
458+
Type: "serial",
459+
Pk: ptr(true),
460+
},
461+
{
462+
Name: "name",
463+
Type: "varchar(255)",
464+
Nullable: ptr(false),
465+
},
466+
},
467+
},
468+
},
469+
},
470+
{
471+
Name: "02_create_constraint_with_missing_migration",
472+
Operations: migrations.Operations{
473+
&migrations.OpCreateConstraint{
474+
Name: "check_name",
475+
Table: "users",
476+
Columns: []string{"name"},
477+
Type: "check",
478+
Up: migrations.OpCreateConstraintUp(map[string]string{
479+
"name": "name",
480+
}),
481+
Down: migrations.OpCreateConstraintDown(map[string]string{
482+
"name": "name",
483+
}),
484+
},
485+
},
486+
},
487+
},
488+
wantStartErr: migrations.FieldRequiredError{Name: "check"},
489+
afterStart: func(t *testing.T, db *sql.DB, schema string) {},
490+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {},
491+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {},
492+
},
273493
})
274494
}
275495

pkg/migrations/types.go

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

schema.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,11 @@
454454
"type": {
455455
"description": "Type of the constraint",
456456
"type": "string",
457-
"enum": ["unique"]
457+
"enum": ["unique", "check"]
458+
},
459+
"check": {
460+
"description": "Check constraint expression",
461+
"type": "string"
458462
},
459463
"up": {
460464
"type": "object",

0 commit comments

Comments
 (0)