Skip to content

Commit 5d63a7f

Browse files
kvchandrew-farries
andauthored
Duplicate unique and check constraints correctly (#466)
Previously, unique and check constraints with multiple columns were not duplicated correctly. We had two issues: 1. Pgroll tried to create the same duplicated index for each renamed column. From now on pgroll creates the index once with the new column names. 2. Pgroll tried to convert the unique index to a constraint multiple times. From now on, the index is only converted once. Required by #464 --------- Co-authored-by: Andrew Farries <[email protected]>
1 parent 85e917c commit 5d63a7f

File tree

7 files changed

+375
-165
lines changed

7 files changed

+375
-165
lines changed

pkg/migrations/duplicate.go

Lines changed: 195 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -17,174 +17,267 @@ import (
1717
// Duplicator duplicates a column in a table, including all constraints and
1818
// comments.
1919
type Duplicator struct {
20+
stmtBuilder *duplicatorStmtBuilder
2021
conn db.DB
21-
table *schema.Table
22-
column *schema.Column
23-
asName string
24-
withoutNotNull bool
25-
withType string
26-
withoutConstraint string
22+
columns map[string]*columnToDuplicate
23+
withoutConstraint []string
24+
}
25+
26+
type columnToDuplicate struct {
27+
column *schema.Column
28+
asName string
29+
withoutNotNull bool
30+
withType string
31+
}
32+
33+
// duplicatorStmtBuilder is a helper for building SQL statements to duplicate
34+
// columns and constraints in a table.
35+
type duplicatorStmtBuilder struct {
36+
table *schema.Table
2737
}
2838

2939
const (
3040
dataTypeMismatchErrorCode pq.ErrorCode = "42804"
3141
undefinedFunctionErrorCode pq.ErrorCode = "42883"
42+
43+
cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)`
44+
cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
45+
cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
3246
)
3347

3448
// NewColumnDuplicator creates a new Duplicator for a column.
35-
func NewColumnDuplicator(conn db.DB, table *schema.Table, column *schema.Column) *Duplicator {
49+
func NewColumnDuplicator(conn db.DB, table *schema.Table, columns ...*schema.Column) *Duplicator {
50+
cols := make(map[string]*columnToDuplicate, len(columns))
51+
for _, column := range columns {
52+
cols[column.Name] = &columnToDuplicate{
53+
column: column,
54+
asName: TemporaryName(column.Name),
55+
withType: column.Type,
56+
}
57+
}
3658
return &Duplicator{
37-
conn: conn,
38-
table: table,
39-
column: column,
40-
asName: TemporaryName(column.Name),
41-
withType: column.Type,
59+
stmtBuilder: &duplicatorStmtBuilder{
60+
table: table,
61+
},
62+
conn: conn,
63+
columns: cols,
64+
withoutConstraint: make([]string, 0),
4265
}
4366
}
4467

4568
// WithType sets the type of the new column.
46-
func (d *Duplicator) WithType(t string) *Duplicator {
47-
d.withType = t
69+
func (d *Duplicator) WithType(columnName, t string) *Duplicator {
70+
d.columns[columnName].withType = t
4871
return d
4972
}
5073

5174
// WithoutConstraint excludes a constraint from being duplicated.
5275
func (d *Duplicator) WithoutConstraint(c string) *Duplicator {
53-
d.withoutConstraint = c
76+
d.withoutConstraint = append(d.withoutConstraint, c)
5477
return d
5578
}
5679

5780
// WithoutNotNull excludes the NOT NULL constraint from being duplicated.
58-
func (d *Duplicator) WithoutNotNull() *Duplicator {
59-
d.withoutNotNull = true
81+
func (d *Duplicator) WithoutNotNull(columnName string) *Duplicator {
82+
d.columns[columnName].withoutNotNull = true
6083
return d
6184
}
6285

6386
// Duplicate duplicates a column in the table, including all constraints and
6487
// comments.
6588
func (d *Duplicator) Duplicate(ctx context.Context) error {
89+
colNames := make([]string, 0, len(d.columns))
90+
for name, c := range d.columns {
91+
colNames = append(colNames, name)
92+
93+
// Duplicate the column with the new type
94+
// and check and fk constraints
95+
if sql := d.stmtBuilder.duplicateColumn(c.column, c.asName, c.withoutNotNull, c.withType, d.withoutConstraint); sql != "" {
96+
_, err := d.conn.ExecContext(ctx, sql)
97+
if err != nil {
98+
return err
99+
}
100+
}
101+
102+
// Duplicate the column's default value
103+
if sql := d.stmtBuilder.duplicateDefault(c.column, c.asName); sql != "" {
104+
_, err := d.conn.ExecContext(ctx, sql)
105+
err = errorIgnoringErrorCode(err, dataTypeMismatchErrorCode)
106+
if err != nil {
107+
return err
108+
}
109+
}
110+
111+
if sql := d.stmtBuilder.duplicateComment(c.column, c.asName); sql != "" {
112+
_, err := d.conn.ExecContext(ctx, sql)
113+
if err != nil {
114+
return err
115+
}
116+
}
117+
}
118+
119+
// Generate SQL to duplicate any check constraints on the columns. This may faile
120+
// if the check constraint is not valid for the new column type, in which case
121+
// the error is ignored.
122+
for _, sql := range d.stmtBuilder.duplicateCheckConstraints(d.withoutConstraint, colNames...) {
123+
// Update the check constraint expression to use the new column names if any of the columns are duplicated
124+
_, err := d.conn.ExecContext(ctx, sql)
125+
err = errorIgnoringErrorCode(err, undefinedFunctionErrorCode)
126+
if err != nil {
127+
return err
128+
}
129+
}
130+
131+
// Generate SQL to duplicate any unique constraints on the columns
132+
// The constraint is duplicated by adding a unique index on the column concurrently.
133+
// The index is converted into a unique constraint on migration completion.
134+
for _, sql := range d.stmtBuilder.duplicateUniqueConstraints(d.withoutConstraint, colNames...) {
135+
// Update the unique constraint columns to use the new column names if any of the columns are duplicated
136+
if _, err := d.conn.ExecContext(ctx, sql); err != nil {
137+
return err
138+
}
139+
}
140+
141+
return nil
142+
}
143+
144+
func (d *duplicatorStmtBuilder) duplicateCheckConstraints(withoutConstraint []string, colNames ...string) []string {
145+
stmts := make([]string, 0, len(d.table.CheckConstraints))
146+
for _, cc := range d.table.CheckConstraints {
147+
if slices.Contains(withoutConstraint, cc.Name) {
148+
continue
149+
}
150+
if duplicatedConstraintColumns := d.duplicatedConstraintColumns(cc.Columns, colNames...); len(duplicatedConstraintColumns) > 0 {
151+
stmts = append(stmts, fmt.Sprintf(cAlterTableAddCheckConstraintSQL,
152+
pq.QuoteIdentifier(d.table.Name),
153+
pq.QuoteIdentifier(DuplicationName(cc.Name)),
154+
rewriteCheckExpression(cc.Definition, duplicatedConstraintColumns...),
155+
))
156+
}
157+
}
158+
return stmts
159+
}
160+
161+
func (d *duplicatorStmtBuilder) duplicateUniqueConstraints(withoutConstraint []string, colNames ...string) []string {
162+
stmts := make([]string, 0, len(d.table.UniqueConstraints))
163+
for _, uc := range d.table.UniqueConstraints {
164+
if slices.Contains(withoutConstraint, uc.Name) {
165+
continue
166+
}
167+
if duplicatedMember, constraintColumns := d.allConstraintColumns(uc.Columns, colNames...); duplicatedMember {
168+
stmts = append(stmts, fmt.Sprintf(cCreateUniqueIndexSQL,
169+
pq.QuoteIdentifier(DuplicationName(uc.Name)),
170+
pq.QuoteIdentifier(d.table.Name),
171+
strings.Join(quoteColumnNames(constraintColumns), ", "),
172+
))
173+
}
174+
}
175+
return stmts
176+
}
177+
178+
// duplicatedConstraintColumns returns a new slice of constraint columns with
179+
// the columns that are duplicated replaced with temporary names.
180+
func (d *duplicatorStmtBuilder) duplicatedConstraintColumns(constraintColumns []string, duplicatedColumns ...string) []string {
181+
newConstraintColumns := make([]string, 0)
182+
for _, column := range constraintColumns {
183+
if slices.Contains(duplicatedColumns, column) {
184+
newConstraintColumns = append(newConstraintColumns, column)
185+
}
186+
}
187+
return newConstraintColumns
188+
}
189+
190+
// allConstraintColumns returns a new slice of constraint columns with the columns
191+
// that are duplicated replaced with temporary names and a boolean indicating if
192+
// any of the columns are duplicated.
193+
func (d *duplicatorStmtBuilder) allConstraintColumns(constraintColumns []string, duplicatedColumns ...string) (bool, []string) {
194+
duplicatedMember := false
195+
newConstraintColumns := make([]string, len(constraintColumns))
196+
for i, column := range constraintColumns {
197+
if slices.Contains(duplicatedColumns, column) {
198+
newConstraintColumns[i] = TemporaryName(column)
199+
duplicatedMember = true
200+
} else {
201+
newConstraintColumns[i] = column
202+
}
203+
}
204+
return duplicatedMember, newConstraintColumns
205+
}
206+
207+
func (d *duplicatorStmtBuilder) duplicateColumn(
208+
column *schema.Column,
209+
asName string,
210+
withoutNotNull bool,
211+
withType string,
212+
withoutConstraint []string,
213+
) string {
66214
const (
67-
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
68-
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
69-
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
70-
cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)`
71-
cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
72-
cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
73-
cCommentOnColumnSQL = `COMMENT ON COLUMN %s.%s IS %s`
215+
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
216+
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
217+
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
74218
)
75219

76220
// Generate SQL to duplicate the column's name and type
77221
sql := fmt.Sprintf(cAlterTableSQL,
78222
pq.QuoteIdentifier(d.table.Name),
79-
pq.QuoteIdentifier(d.asName),
80-
d.withType)
223+
pq.QuoteIdentifier(asName),
224+
withType)
81225

82226
// Generate SQL to add an unchecked NOT NULL constraint if the original column
83227
// is NOT NULL. The constraint will be validated on migration completion.
84-
if !d.column.Nullable && !d.withoutNotNull {
228+
if !column.Nullable && !withoutNotNull {
85229
sql += fmt.Sprintf(", "+cAddCheckConstraintSQL,
86-
pq.QuoteIdentifier(DuplicationName(NotNullConstraintName(d.column.Name))),
87-
fmt.Sprintf("CHECK (%s IS NOT NULL)", pq.QuoteIdentifier(d.asName)),
230+
pq.QuoteIdentifier(DuplicationName(NotNullConstraintName(column.Name))),
231+
fmt.Sprintf("CHECK (%s IS NOT NULL)", pq.QuoteIdentifier(asName)),
88232
)
89233
}
90234

91235
// Generate SQL to duplicate any foreign key constraints on the column
92236
for _, fk := range d.table.ForeignKeys {
93-
if fk.Name == d.withoutConstraint {
237+
if slices.Contains(withoutConstraint, fk.Name) {
94238
continue
95239
}
96240

97-
if slices.Contains(fk.Columns, d.column.Name) {
241+
if slices.Contains(fk.Columns, column.Name) {
98242
sql += fmt.Sprintf(", "+cAddForeignKeySQL,
99243
pq.QuoteIdentifier(DuplicationName(fk.Name)),
100-
strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, d.column.Name, d.asName)), ", "),
244+
strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, column.Name, asName)), ", "),
101245
pq.QuoteIdentifier(fk.ReferencedTable),
102246
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
103247
fk.OnDelete,
104248
)
105249
}
106250
}
107251

108-
_, err := d.conn.ExecContext(ctx, sql)
109-
if err != nil {
110-
return err
252+
return sql
253+
}
254+
255+
func (d *duplicatorStmtBuilder) duplicateDefault(column *schema.Column, asName string) string {
256+
if column.Default == nil {
257+
return ""
111258
}
112259

260+
const cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
261+
113262
// Generate SQL to duplicate any default value on the column. This may fail
114263
// if the default value is not valid for the new column type, in which case
115264
// the error is ignored.
116-
if d.column.Default != nil {
117-
sql := fmt.Sprintf(cSetDefaultSQL, pq.QuoteIdentifier(d.table.Name), d.asName, *d.column.Default)
118-
119-
_, err := d.conn.ExecContext(ctx, sql)
265+
return fmt.Sprintf(cSetDefaultSQL, pq.QuoteIdentifier(d.table.Name), asName, *column.Default)
266+
}
120267

121-
err = errorIgnoringErrorCode(err, dataTypeMismatchErrorCode)
122-
if err != nil {
123-
return err
124-
}
268+
func (d *duplicatorStmtBuilder) duplicateComment(column *schema.Column, asName string) string {
269+
if column.Comment == "" {
270+
return ""
125271
}
126272

127-
// Generate SQL to duplicate any check constraints on the column. This may faile
128-
// if the check constraint is not valid for the new column type, in which case
129-
// the error is ignored.
130-
for _, cc := range d.table.CheckConstraints {
131-
if cc.Name == d.withoutConstraint {
132-
continue
133-
}
134-
135-
if slices.Contains(cc.Columns, d.column.Name) {
136-
sql := fmt.Sprintf(cAlterTableAddCheckConstraintSQL,
137-
pq.QuoteIdentifier(d.table.Name),
138-
pq.QuoteIdentifier(DuplicationName(cc.Name)),
139-
rewriteCheckExpression(cc.Definition, d.column.Name, d.asName),
140-
)
141-
142-
_, err := d.conn.ExecContext(ctx, sql)
143-
144-
err = errorIgnoringErrorCode(err, undefinedFunctionErrorCode)
145-
if err != nil {
146-
return err
147-
}
148-
}
149-
}
273+
const cCommentOnColumnSQL = `COMMENT ON COLUMN %s.%s IS %s`
150274

151275
// Generate SQL to duplicate the column's comment
152-
if d.column.Comment != "" {
153-
sql = fmt.Sprintf(cCommentOnColumnSQL,
154-
pq.QuoteIdentifier(d.table.Name),
155-
pq.QuoteIdentifier(d.asName),
156-
pq.QuoteLiteral(d.column.Comment),
157-
)
158-
159-
_, err = d.conn.ExecContext(ctx, sql)
160-
if err != nil {
161-
return err
162-
}
163-
}
164-
165-
// Generate SQL to duplicate any unique constraints on the column
166-
// The constraint is duplicated by adding a unique index on the column concurrently.
167-
// The index is converted into a unique constraint on migration completion.
168-
for _, uc := range d.table.UniqueConstraints {
169-
if uc.Name == d.withoutConstraint {
170-
continue
171-
}
172-
173-
if slices.Contains(uc.Columns, d.column.Name) {
174-
sql = fmt.Sprintf(cCreateUniqueIndexSQL,
175-
pq.QuoteIdentifier(DuplicationName(uc.Name)),
176-
pq.QuoteIdentifier(d.table.Name),
177-
strings.Join(quoteColumnNames(copyAndReplace(uc.Columns, d.column.Name, d.asName)), ", "),
178-
)
179-
180-
_, err = d.conn.ExecContext(ctx, sql)
181-
if err != nil {
182-
return err
183-
}
184-
}
185-
}
186-
187-
return nil
276+
return fmt.Sprintf(cCommentOnColumnSQL,
277+
pq.QuoteIdentifier(d.table.Name),
278+
pq.QuoteIdentifier(asName),
279+
pq.QuoteLiteral(column.Comment),
280+
)
188281
}
189282

190283
// DiplicationName returns the name of a duplicated column.

0 commit comments

Comments
 (0)