Skip to content

Commit d927f3e

Browse files
Support table and column rename operations preceding drop_multicolumn_constraint operations (#684)
Ensure that `drop_multicolumn_constraint` operations can be preceded by rename table and rename column operations as in the following example: ```json { "name": "24_drop_constraint", "operations": [ { "rename_table": { "from": "items", "to": "products" } }, { "rename_column": { "table": "products", "from": "name", "to": "item_name" } }, { "drop_multicolumn_constraint": { "table": "products", "name": "name_length", "up": { "item_name": "item_name" }, "down": { "item_name": "SELECT CASE WHEN length(item_name) <= 3 THEN LPAD(item_name, 4, '-') ELSE item_name END" } } } ] } ``` Part of #239.
1 parent 97dbca2 commit d927f3e

File tree

4 files changed

+262
-6
lines changed

4 files changed

+262
-6
lines changed

pkg/migrations/op_drop_multicolumn_constraint.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,13 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, conn db.DB, lat
3030
}
3131
}
3232

33-
// Duplicate each of the columns covered by the constraint to be dropped
33+
// Duplicate each of the columns covered by the constraint to be dropped.
34+
// Each column is duplicated assuming its final name after the migration is
35+
// completed.
3436
d := NewColumnDuplicator(conn, table, columns...).WithoutConstraint(o.Name)
37+
for _, colName := range constraintColumns {
38+
d = d.WithName(table.GetColumn(colName).Name, TemporaryName(colName))
39+
}
3540
if err := d.Duplicate(ctx); err != nil {
3641
return nil, fmt.Errorf("failed to duplicate column: %w", err)
3742
}
@@ -45,7 +50,7 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, conn db.DB, lat
4550
Columns: table.Columns,
4651
SchemaName: s.Name,
4752
LatestSchema: latestSchema,
48-
TableName: o.Table,
53+
TableName: table.Name,
4954
PhysicalColumn: TemporaryName(columnName),
5055
SQL: o.upSQL(columnName),
5156
})
@@ -55,7 +60,9 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, conn db.DB, lat
5560

5661
// Add the new column to the internal schema representation. This is done
5762
// here, before creation of the down trigger, so that the trigger can declare
58-
// a variable for the new column.
63+
// a variable for the new column. Save the old column name for use as the
64+
// physical column name in the down trigger first.
65+
oldPhysicalColumn := table.GetColumn(columnName).Name
5966
table.AddColumn(columnName, &schema.Column{
6067
Name: TemporaryName(columnName),
6168
})
@@ -67,8 +74,8 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, conn db.DB, lat
6774
Columns: table.Columns,
6875
SchemaName: s.Name,
6976
LatestSchema: latestSchema,
70-
TableName: o.Table,
71-
PhysicalColumn: columnName,
77+
TableName: table.Name,
78+
PhysicalColumn: oldPhysicalColumn,
7279
SQL: o.Down[columnName],
7380
})
7481
if err != nil {
@@ -125,7 +132,7 @@ func (o *OpDropMultiColumnConstraint) Rollback(ctx context.Context, conn db.DB,
125132
for _, columnName := range table.GetConstraintColumns(o.Name) {
126133
// Drop the new column
127134
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
128-
pq.QuoteIdentifier(o.Table),
135+
pq.QuoteIdentifier(table.Name),
129136
pq.QuoteIdentifier(TemporaryName(columnName)),
130137
))
131138
if err != nil {

pkg/migrations/op_drop_multicolumn_constraint_test.go

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,224 @@ func TestDropMultiColumnConstraint(t *testing.T) {
410410
})
411411
}
412412

413+
func TestDropMultiColumnConstraintInMultiOperationMigrations(t *testing.T) {
414+
t.Parallel()
415+
416+
ExecuteTests(t, TestCases{
417+
{
418+
name: "rename table, drop constraint",
419+
migrations: []migrations.Migration{
420+
{
421+
Name: "01_create_table",
422+
Operations: migrations.Operations{
423+
&migrations.OpCreateTable{
424+
Name: "items",
425+
Columns: []migrations.Column{
426+
{
427+
Name: "id",
428+
Type: "int",
429+
Pk: true,
430+
},
431+
{
432+
Name: "name",
433+
Type: "varchar(255)",
434+
Nullable: true,
435+
Check: &migrations.CheckConstraint{
436+
Name: "check_name_length",
437+
Constraint: "length(name) > 3",
438+
},
439+
},
440+
},
441+
},
442+
},
443+
},
444+
{
445+
Name: "02_multi_operation",
446+
Operations: migrations.Operations{
447+
&migrations.OpRenameTable{
448+
From: "items",
449+
To: "products",
450+
},
451+
&migrations.OpDropMultiColumnConstraint{
452+
Table: "products",
453+
Name: "check_name_length",
454+
Down: map[string]string{
455+
"name": "SELECT CASE WHEN length(name) <= 3 THEN LPAD(name, 4, '-') ELSE name END",
456+
},
457+
Up: map[string]string{
458+
"name": "name",
459+
},
460+
},
461+
},
462+
},
463+
},
464+
afterStart: func(t *testing.T, db *sql.DB, schema string) {
465+
// Can insert a row into the new schema that violates the constraint
466+
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
467+
"id": "1",
468+
"name": "a",
469+
})
470+
471+
// Can't insert a row into the old schema that violates the constraint
472+
MustNotInsert(t, db, schema, "01_create_table", "items", map[string]string{
473+
"id": "2",
474+
"name": "b",
475+
}, testutils.CheckViolationErrorCode)
476+
477+
// Can insert a row into the old schema that meets the constraint
478+
MustInsert(t, db, schema, "01_create_table", "items", map[string]string{
479+
"id": "2",
480+
"name": "bananas",
481+
})
482+
483+
// The new view has the expected rows
484+
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
485+
assert.Equal(t, []map[string]any{
486+
{"id": 1, "name": "a"},
487+
{"id": 2, "name": "bananas"},
488+
}, rows)
489+
490+
// The old view has the expected rows
491+
rows = MustSelect(t, db, schema, "01_create_table", "items")
492+
assert.Equal(t, []map[string]any{
493+
{"id": 1, "name": "---a"}, // rewritten by the down migration
494+
{"id": 2, "name": "bananas"},
495+
}, rows)
496+
},
497+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
498+
// The table has been cleaned up
499+
TableMustBeCleanedUp(t, db, schema, "items", "name")
500+
},
501+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
502+
// Can insert a row into the new schema that violates the constraint
503+
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
504+
"id": "3",
505+
"name": "c",
506+
})
507+
508+
// The new view has the expected rows
509+
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
510+
assert.Equal(t, []map[string]any{
511+
{"id": 1, "name": "---a"},
512+
{"id": 2, "name": "bananas"},
513+
{"id": 3, "name": "c"},
514+
}, rows)
515+
516+
// The table has been cleaned up
517+
TableMustBeCleanedUp(t, db, schema, "products", "name")
518+
},
519+
},
520+
{
521+
name: "rename table, rename column, drop constraint",
522+
migrations: []migrations.Migration{
523+
{
524+
Name: "01_create_table",
525+
Operations: migrations.Operations{
526+
&migrations.OpCreateTable{
527+
Name: "items",
528+
Columns: []migrations.Column{
529+
{
530+
Name: "id",
531+
Type: "int",
532+
Pk: true,
533+
},
534+
{
535+
Name: "name",
536+
Type: "varchar(255)",
537+
Nullable: true,
538+
Check: &migrations.CheckConstraint{
539+
Name: "check_name_length",
540+
Constraint: "length(name) > 3",
541+
},
542+
},
543+
},
544+
},
545+
},
546+
},
547+
{
548+
Name: "02_multi_operation",
549+
Operations: migrations.Operations{
550+
&migrations.OpRenameTable{
551+
From: "items",
552+
To: "products",
553+
},
554+
&migrations.OpRenameColumn{
555+
Table: "products",
556+
From: "name",
557+
To: "item_name",
558+
},
559+
&migrations.OpDropMultiColumnConstraint{
560+
Table: "products",
561+
Name: "check_name_length",
562+
Down: map[string]string{
563+
"item_name": "SELECT CASE WHEN length(item_name) <= 3 THEN LPAD(item_name, 4, '-') ELSE item_name END",
564+
},
565+
Up: map[string]string{
566+
"item_name": "item_name",
567+
},
568+
},
569+
},
570+
},
571+
},
572+
afterStart: func(t *testing.T, db *sql.DB, schema string) {
573+
// Can insert a row into the new schema that violates the constraint
574+
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
575+
"id": "1",
576+
"item_name": "a",
577+
})
578+
579+
// Can't insert a row into the old schema that violates the constraint
580+
MustNotInsert(t, db, schema, "01_create_table", "items", map[string]string{
581+
"id": "2",
582+
"name": "b",
583+
}, testutils.CheckViolationErrorCode)
584+
585+
// Can insert a row into the old schema that meets the constraint
586+
MustInsert(t, db, schema, "01_create_table", "items", map[string]string{
587+
"id": "2",
588+
"name": "bananas",
589+
})
590+
591+
// The new view has the expected rows
592+
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
593+
assert.Equal(t, []map[string]any{
594+
{"id": 1, "item_name": "a"},
595+
{"id": 2, "item_name": "bananas"},
596+
}, rows)
597+
598+
// The old view has the expected rows
599+
rows = MustSelect(t, db, schema, "01_create_table", "items")
600+
assert.Equal(t, []map[string]any{
601+
{"id": 1, "name": "---a"}, // rewritten by the down migration
602+
{"id": 2, "name": "bananas"},
603+
}, rows)
604+
},
605+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
606+
// The table has been cleaned up
607+
TableMustBeCleanedUp(t, db, schema, "items", "name")
608+
},
609+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
610+
// Can insert a row into the new schema that violates the constraint
611+
MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{
612+
"id": "3",
613+
"item_name": "c",
614+
})
615+
616+
// The new view has the expected rows
617+
rows := MustSelect(t, db, schema, "02_multi_operation", "products")
618+
assert.Equal(t, []map[string]any{
619+
{"id": 1, "item_name": "---a"},
620+
{"id": 2, "item_name": "bananas"},
621+
{"id": 3, "item_name": "c"},
622+
}, rows)
623+
624+
// The table has been cleaned up
625+
TableMustBeCleanedUp(t, db, schema, "products", "name")
626+
},
627+
},
628+
})
629+
}
630+
413631
func TestDropMultiColumnConstraintValidation(t *testing.T) {
414632
t.Parallel()
415633

pkg/migrations/op_rename_column.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ func (o *OpRenameColumn) Start(ctx context.Context, conn db.DB, latestSchema str
2424
}
2525
table.RenameColumn(o.From, o.To)
2626

27+
// Update the name of the column in any constraints that reference the
28+
// renamed column.
29+
table.RenameConstraintColumns(o.From, o.To)
30+
2731
return nil, nil
2832
}
2933

@@ -44,6 +48,10 @@ func (o *OpRenameColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransfo
4448
column := table.GetColumn(o.To)
4549
column.Name = o.To
4650

51+
// Update the name of the column in any constraints that reference the
52+
// renamed column.
53+
table.RenameConstraintColumns(o.From, o.To)
54+
4755
return err
4856
}
4957

@@ -86,6 +94,7 @@ func (o *OpRenameColumn) Validate(ctx context.Context, s *schema.Schema) error {
8694
// Update the in-memory schema to reflect the column rename so that it is
8795
// visible to subsequent operations' validation steps.
8896
table.RenameColumn(o.From, o.To)
97+
table.RenameConstraintColumns(o.From, o.To)
8998

9099
return nil
91100
}

pkg/schema/schema.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,28 @@ func (t *Table) GetConstraintColumns(name string) []string {
277277
return slices.Compact(columns)
278278
}
279279

280+
// RenameConstraintColumns renames all occurrences of a column name in any
281+
// constraint on the table from `from` to `to`.
282+
func (t *Table) RenameConstraintColumns(from, to string) {
283+
updateColumns := func(columns []string) {
284+
for i, c := range columns {
285+
if c == from {
286+
columns[i] = to
287+
}
288+
}
289+
}
290+
291+
for _, cc := range t.CheckConstraints {
292+
updateColumns(cc.Columns)
293+
}
294+
for _, uc := range t.UniqueConstraints {
295+
updateColumns(uc.Columns)
296+
}
297+
for _, fk := range t.ForeignKeys {
298+
updateColumns(fk.Columns)
299+
}
300+
}
301+
280302
// GetPrimaryKey returns the columns that make up the primary key
281303
func (t *Table) GetPrimaryKey() (columns []*Column) {
282304
for _, name := range t.PrimaryKey {

0 commit comments

Comments
 (0)