Skip to content

Commit c09619a

Browse files
Add support for alter_table set check operations in multi-operation migrations (#622)
Ensure that multi-operation migrations combining `alter_column` operations setting `CHECK` constraints work in combination with other operations. Add testcases for: * rename table, set `CHECK` constraint * rename table, rename column, set `CHECK` constraint Previously these migrations would fail as the `alter_column` operation was unaware of the changes made by the preceding operation. Part of #239
1 parent db21651 commit c09619a

File tree

2 files changed

+221
-3
lines changed

2 files changed

+221
-3
lines changed

pkg/migrations/op_set_check.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn db.DB, latestSche
2727
table := s.GetTable(o.Table)
2828

2929
// Add the check constraint to the new column as NOT VALID.
30-
if err := o.addCheckConstraint(ctx, conn); err != nil {
30+
if err := o.addCheckConstraint(ctx, conn, s); err != nil {
3131
return nil, fmt.Errorf("failed to add check constraint: %w", err)
3232
}
3333

@@ -82,9 +82,11 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e
8282
return nil
8383
}
8484

85-
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.DB) error {
85+
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.DB, s *schema.Schema) error {
86+
table := s.GetTable(o.Table)
87+
8688
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
87-
pq.QuoteIdentifier(o.Table),
89+
pq.QuoteIdentifier(table.Name),
8890
pq.QuoteIdentifier(o.Check.Name),
8991
rewriteCheckExpression(o.Check.Constraint, o.Column),
9092
))

pkg/migrations/op_set_check_test.go

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,222 @@ func TestSetCheckConstraint(t *testing.T) {
594594
})
595595
}
596596

597+
func TestSetCheckInMultiOperationMigrations(t *testing.T) {
598+
t.Parallel()
599+
600+
ExecuteTests(t, TestCases{
601+
{
602+
name: "rename table, set not null",
603+
migrations: []migrations.Migration{
604+
{
605+
Name: "01_create_table",
606+
Operations: migrations.Operations{
607+
&migrations.OpCreateTable{
608+
Name: "items",
609+
Columns: []migrations.Column{
610+
{
611+
Name: "id",
612+
Type: "int",
613+
Pk: true,
614+
},
615+
{
616+
Name: "name",
617+
Type: "varchar(255)",
618+
Nullable: true,
619+
},
620+
},
621+
},
622+
},
623+
},
624+
{
625+
Name: "02_multi_operation",
626+
Operations: migrations.Operations{
627+
&migrations.OpRenameTable{
628+
From: "items",
629+
To: "products",
630+
},
631+
&migrations.OpAlterColumn{
632+
Table: "products",
633+
Column: "name",
634+
Check: &migrations.CheckConstraint{
635+
Name: "check_name_length",
636+
Constraint: "LENGTH(name) > 2",
637+
},
638+
Up: "SELECT CASE WHEN length(name) > 2 THEN name ELSE name || '---' END",
639+
Down: "name",
640+
},
641+
},
642+
},
643+
},
644+
afterStart: func(t *testing.T, db *sql.DB, schema string) {
645+
// Can insert a row into the new view that meets the check constraint
646+
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
647+
"id": "1",
648+
"name": "abc",
649+
})
650+
651+
// Can't insert a row into the new view that violates the check constraint
652+
MustNotInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
653+
"id": "2",
654+
"name": "x",
655+
}, testutils.CheckViolationErrorCode)
656+
657+
// Can insert a row into the old view that violates the check constraint
658+
MustInsert(t, db, schema, "01_create_table", "items", map[string]string{
659+
"id": "3",
660+
"name": "x",
661+
})
662+
663+
// The new view has the expected rows
664+
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
665+
assert.Equal(t, []map[string]any{
666+
{"id": 1, "name": "abc"},
667+
{"id": 3, "name": "x---"},
668+
}, rows)
669+
670+
// The old view has the expected rows
671+
rows = MustSelect(t, db, schema, "01_create_table", "items")
672+
assert.Equal(t, []map[string]any{
673+
{"id": 1, "name": "abc"},
674+
{"id": 3, "name": "x"},
675+
}, rows)
676+
},
677+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
678+
// The table has been cleaned up
679+
TableMustBeCleanedUp(t, db, schema, "items", "name")
680+
},
681+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
682+
// Can insert a row into the new view that meets the check constraint
683+
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
684+
"id": "4",
685+
"name": "def",
686+
})
687+
688+
// Can't insert a row into the new view that violates the check constraint
689+
MustNotInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
690+
"id": "5",
691+
"name": "x",
692+
}, testutils.CheckViolationErrorCode)
693+
694+
// The new view has the expected rows
695+
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
696+
assert.Equal(t, []map[string]any{
697+
{"id": 1, "name": "abc"},
698+
{"id": 3, "name": "x---"},
699+
{"id": 4, "name": "def"},
700+
}, rows)
701+
},
702+
},
703+
{
704+
name: "rename table, rename column set not null",
705+
migrations: []migrations.Migration{
706+
{
707+
Name: "01_create_table",
708+
Operations: migrations.Operations{
709+
&migrations.OpCreateTable{
710+
Name: "items",
711+
Columns: []migrations.Column{
712+
{
713+
Name: "id",
714+
Type: "int",
715+
Pk: true,
716+
},
717+
{
718+
Name: "name",
719+
Type: "varchar(255)",
720+
Nullable: true,
721+
},
722+
},
723+
},
724+
},
725+
},
726+
{
727+
Name: "02_multi_operation",
728+
Operations: migrations.Operations{
729+
&migrations.OpRenameTable{
730+
From: "items",
731+
To: "products",
732+
},
733+
&migrations.OpRenameColumn{
734+
Table: "products",
735+
From: "name",
736+
To: "item_name",
737+
},
738+
&migrations.OpAlterColumn{
739+
Table: "products",
740+
Column: "item_name",
741+
Check: &migrations.CheckConstraint{
742+
Name: "check_name_length",
743+
Constraint: "LENGTH(item_name) > 2",
744+
},
745+
Up: "SELECT CASE WHEN length(item_name) > 2 THEN item_name ELSE item_name || '---' END",
746+
Down: "item_name",
747+
},
748+
},
749+
},
750+
},
751+
afterStart: func(t *testing.T, db *sql.DB, schema string) {
752+
// Can insert a row into the new view that meets the check constraint
753+
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
754+
"id": "1",
755+
"item_name": "abc",
756+
})
757+
758+
// Can't insert a row into the new view that violates the check constraint
759+
MustNotInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
760+
"id": "2",
761+
"item_name": "x",
762+
}, testutils.CheckViolationErrorCode)
763+
764+
// Can insert a row into the old view that violates the check constraint
765+
MustInsert(t, db, schema, "01_create_table", "items", map[string]string{
766+
"id": "3",
767+
"name": "x",
768+
})
769+
770+
// The new view has the expected rows
771+
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
772+
assert.Equal(t, []map[string]any{
773+
{"id": 1, "item_name": "abc"},
774+
{"id": 3, "item_name": "x---"},
775+
}, rows)
776+
777+
// The old view has the expected rows
778+
rows = MustSelect(t, db, schema, "01_create_table", "items")
779+
assert.Equal(t, []map[string]any{
780+
{"id": 1, "name": "abc"},
781+
{"id": 3, "name": "x"},
782+
}, rows)
783+
},
784+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
785+
// The table has been cleaned up
786+
TableMustBeCleanedUp(t, db, schema, "items", "name")
787+
},
788+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
789+
// Can insert a row into the new view that meets the check constraint
790+
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
791+
"id": "4",
792+
"item_name": "def",
793+
})
794+
795+
// Can't insert a row into the new view that violates the check constraint
796+
MustNotInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
797+
"id": "5",
798+
"item_name": "x",
799+
}, testutils.CheckViolationErrorCode)
800+
801+
// The new view has the expected rows
802+
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
803+
assert.Equal(t, []map[string]any{
804+
{"id": 1, "item_name": "abc"},
805+
{"id": 3, "item_name": "x---"},
806+
{"id": 4, "item_name": "def"},
807+
}, rows)
808+
},
809+
},
810+
})
811+
}
812+
597813
func TestSetCheckConstraintValidation(t *testing.T) {
598814
t.Parallel()
599815

0 commit comments

Comments
 (0)