@@ -17,174 +17,267 @@ import (
17
17
// Duplicator duplicates a column in a table, including all constraints and
18
18
// comments.
19
19
type Duplicator struct {
20
+ stmtBuilder * duplicatorStmtBuilder
20
21
conn db.DB
21
- table * schema.Table
22
- column * schema.Column
23
- asName string
24
- withoutNotNull bool
25
- withType string
26
- withoutConstraint string
22
+ columns map [string ]* columnToDuplicate
23
+ withoutConstraint []string
24
+ }
25
+
26
+ type columnToDuplicate struct {
27
+ column * schema.Column
28
+ asName string
29
+ withoutNotNull bool
30
+ withType string
31
+ }
32
+
33
+ // duplicatorStmtBuilder is a helper for building SQL statements to duplicate
34
+ // columns and constraints in a table.
35
+ type duplicatorStmtBuilder struct {
36
+ table * schema.Table
27
37
}
28
38
29
39
const (
30
40
dataTypeMismatchErrorCode pq.ErrorCode = "42804"
31
41
undefinedFunctionErrorCode pq.ErrorCode = "42883"
42
+
43
+ cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)`
44
+ cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
45
+ cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
32
46
)
33
47
34
48
// NewColumnDuplicator creates a new Duplicator for a column.
35
- func NewColumnDuplicator (conn db.DB , table * schema.Table , column * schema.Column ) * Duplicator {
49
+ func NewColumnDuplicator (conn db.DB , table * schema.Table , columns ... * schema.Column ) * Duplicator {
50
+ cols := make (map [string ]* columnToDuplicate , len (columns ))
51
+ for _ , column := range columns {
52
+ cols [column .Name ] = & columnToDuplicate {
53
+ column : column ,
54
+ asName : TemporaryName (column .Name ),
55
+ withType : column .Type ,
56
+ }
57
+ }
36
58
return & Duplicator {
37
- conn : conn ,
38
- table : table ,
39
- column : column ,
40
- asName : TemporaryName (column .Name ),
41
- withType : column .Type ,
59
+ stmtBuilder : & duplicatorStmtBuilder {
60
+ table : table ,
61
+ },
62
+ conn : conn ,
63
+ columns : cols ,
64
+ withoutConstraint : make ([]string , 0 ),
42
65
}
43
66
}
44
67
45
68
// WithType sets the type of the new column.
46
- func (d * Duplicator ) WithType (t string ) * Duplicator {
47
- d .withType = t
69
+ func (d * Duplicator ) WithType (columnName , t string ) * Duplicator {
70
+ d .columns [ columnName ]. withType = t
48
71
return d
49
72
}
50
73
51
74
// WithoutConstraint excludes a constraint from being duplicated.
52
75
func (d * Duplicator ) WithoutConstraint (c string ) * Duplicator {
53
- d .withoutConstraint = c
76
+ d .withoutConstraint = append ( d . withoutConstraint , c )
54
77
return d
55
78
}
56
79
57
80
// WithoutNotNull excludes the NOT NULL constraint from being duplicated.
58
- func (d * Duplicator ) WithoutNotNull () * Duplicator {
59
- d .withoutNotNull = true
81
+ func (d * Duplicator ) WithoutNotNull (columnName string ) * Duplicator {
82
+ d .columns [ columnName ]. withoutNotNull = true
60
83
return d
61
84
}
62
85
63
86
// Duplicate duplicates a column in the table, including all constraints and
64
87
// comments.
65
88
func (d * Duplicator ) Duplicate (ctx context.Context ) error {
89
+ colNames := make ([]string , 0 , len (d .columns ))
90
+ for name , c := range d .columns {
91
+ colNames = append (colNames , name )
92
+
93
+ // Duplicate the column with the new type
94
+ // and check and fk constraints
95
+ if sql := d .stmtBuilder .duplicateColumn (c .column , c .asName , c .withoutNotNull , c .withType , d .withoutConstraint ); sql != "" {
96
+ _ , err := d .conn .ExecContext (ctx , sql )
97
+ if err != nil {
98
+ return err
99
+ }
100
+ }
101
+
102
+ // Duplicate the column's default value
103
+ if sql := d .stmtBuilder .duplicateDefault (c .column , c .asName ); sql != "" {
104
+ _ , err := d .conn .ExecContext (ctx , sql )
105
+ err = errorIgnoringErrorCode (err , dataTypeMismatchErrorCode )
106
+ if err != nil {
107
+ return err
108
+ }
109
+ }
110
+
111
+ if sql := d .stmtBuilder .duplicateComment (c .column , c .asName ); sql != "" {
112
+ _ , err := d .conn .ExecContext (ctx , sql )
113
+ if err != nil {
114
+ return err
115
+ }
116
+ }
117
+ }
118
+
119
+ // Generate SQL to duplicate any check constraints on the columns. This may faile
120
+ // if the check constraint is not valid for the new column type, in which case
121
+ // the error is ignored.
122
+ for _ , sql := range d .stmtBuilder .duplicateCheckConstraints (d .withoutConstraint , colNames ... ) {
123
+ // Update the check constraint expression to use the new column names if any of the columns are duplicated
124
+ _ , err := d .conn .ExecContext (ctx , sql )
125
+ err = errorIgnoringErrorCode (err , undefinedFunctionErrorCode )
126
+ if err != nil {
127
+ return err
128
+ }
129
+ }
130
+
131
+ // Generate SQL to duplicate any unique constraints on the columns
132
+ // The constraint is duplicated by adding a unique index on the column concurrently.
133
+ // The index is converted into a unique constraint on migration completion.
134
+ for _ , sql := range d .stmtBuilder .duplicateUniqueConstraints (d .withoutConstraint , colNames ... ) {
135
+ // Update the unique constraint columns to use the new column names if any of the columns are duplicated
136
+ if _ , err := d .conn .ExecContext (ctx , sql ); err != nil {
137
+ return err
138
+ }
139
+ }
140
+
141
+ return nil
142
+ }
143
+
144
+ func (d * duplicatorStmtBuilder ) duplicateCheckConstraints (withoutConstraint []string , colNames ... string ) []string {
145
+ stmts := make ([]string , 0 , len (d .table .CheckConstraints ))
146
+ for _ , cc := range d .table .CheckConstraints {
147
+ if slices .Contains (withoutConstraint , cc .Name ) {
148
+ continue
149
+ }
150
+ if duplicatedConstraintColumns := d .duplicatedConstraintColumns (cc .Columns , colNames ... ); len (duplicatedConstraintColumns ) > 0 {
151
+ stmts = append (stmts , fmt .Sprintf (cAlterTableAddCheckConstraintSQL ,
152
+ pq .QuoteIdentifier (d .table .Name ),
153
+ pq .QuoteIdentifier (DuplicationName (cc .Name )),
154
+ rewriteCheckExpression (cc .Definition , duplicatedConstraintColumns ... ),
155
+ ))
156
+ }
157
+ }
158
+ return stmts
159
+ }
160
+
161
+ func (d * duplicatorStmtBuilder ) duplicateUniqueConstraints (withoutConstraint []string , colNames ... string ) []string {
162
+ stmts := make ([]string , 0 , len (d .table .UniqueConstraints ))
163
+ for _ , uc := range d .table .UniqueConstraints {
164
+ if slices .Contains (withoutConstraint , uc .Name ) {
165
+ continue
166
+ }
167
+ if duplicatedMember , constraintColumns := d .allConstraintColumns (uc .Columns , colNames ... ); duplicatedMember {
168
+ stmts = append (stmts , fmt .Sprintf (cCreateUniqueIndexSQL ,
169
+ pq .QuoteIdentifier (DuplicationName (uc .Name )),
170
+ pq .QuoteIdentifier (d .table .Name ),
171
+ strings .Join (quoteColumnNames (constraintColumns ), ", " ),
172
+ ))
173
+ }
174
+ }
175
+ return stmts
176
+ }
177
+
178
+ // duplicatedConstraintColumns returns a new slice of constraint columns with
179
+ // the columns that are duplicated replaced with temporary names.
180
+ func (d * duplicatorStmtBuilder ) duplicatedConstraintColumns (constraintColumns []string , duplicatedColumns ... string ) []string {
181
+ newConstraintColumns := make ([]string , 0 )
182
+ for _ , column := range constraintColumns {
183
+ if slices .Contains (duplicatedColumns , column ) {
184
+ newConstraintColumns = append (newConstraintColumns , column )
185
+ }
186
+ }
187
+ return newConstraintColumns
188
+ }
189
+
190
+ // allConstraintColumns returns a new slice of constraint columns with the columns
191
+ // that are duplicated replaced with temporary names and a boolean indicating if
192
+ // any of the columns are duplicated.
193
+ func (d * duplicatorStmtBuilder ) allConstraintColumns (constraintColumns []string , duplicatedColumns ... string ) (bool , []string ) {
194
+ duplicatedMember := false
195
+ newConstraintColumns := make ([]string , len (constraintColumns ))
196
+ for i , column := range constraintColumns {
197
+ if slices .Contains (duplicatedColumns , column ) {
198
+ newConstraintColumns [i ] = TemporaryName (column )
199
+ duplicatedMember = true
200
+ } else {
201
+ newConstraintColumns [i ] = column
202
+ }
203
+ }
204
+ return duplicatedMember , newConstraintColumns
205
+ }
206
+
207
+ func (d * duplicatorStmtBuilder ) duplicateColumn (
208
+ column * schema.Column ,
209
+ asName string ,
210
+ withoutNotNull bool ,
211
+ withType string ,
212
+ withoutConstraint []string ,
213
+ ) string {
66
214
const (
67
- cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
68
- cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
69
- cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
70
- cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)`
71
- cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
72
- cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
73
- cCommentOnColumnSQL = `COMMENT ON COLUMN %s.%s IS %s`
215
+ cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
216
+ cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
217
+ cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
74
218
)
75
219
76
220
// Generate SQL to duplicate the column's name and type
77
221
sql := fmt .Sprintf (cAlterTableSQL ,
78
222
pq .QuoteIdentifier (d .table .Name ),
79
- pq .QuoteIdentifier (d . asName ),
80
- d . withType )
223
+ pq .QuoteIdentifier (asName ),
224
+ withType )
81
225
82
226
// Generate SQL to add an unchecked NOT NULL constraint if the original column
83
227
// is NOT NULL. The constraint will be validated on migration completion.
84
- if ! d . column .Nullable && ! d . withoutNotNull {
228
+ if ! column .Nullable && ! withoutNotNull {
85
229
sql += fmt .Sprintf (", " + cAddCheckConstraintSQL ,
86
- pq .QuoteIdentifier (DuplicationName (NotNullConstraintName (d . column .Name ))),
87
- fmt .Sprintf ("CHECK (%s IS NOT NULL)" , pq .QuoteIdentifier (d . asName )),
230
+ pq .QuoteIdentifier (DuplicationName (NotNullConstraintName (column .Name ))),
231
+ fmt .Sprintf ("CHECK (%s IS NOT NULL)" , pq .QuoteIdentifier (asName )),
88
232
)
89
233
}
90
234
91
235
// Generate SQL to duplicate any foreign key constraints on the column
92
236
for _ , fk := range d .table .ForeignKeys {
93
- if fk . Name == d . withoutConstraint {
237
+ if slices . Contains ( withoutConstraint , fk . Name ) {
94
238
continue
95
239
}
96
240
97
- if slices .Contains (fk .Columns , d . column .Name ) {
241
+ if slices .Contains (fk .Columns , column .Name ) {
98
242
sql += fmt .Sprintf (", " + cAddForeignKeySQL ,
99
243
pq .QuoteIdentifier (DuplicationName (fk .Name )),
100
- strings .Join (quoteColumnNames (copyAndReplace (fk .Columns , d . column .Name , d . asName )), ", " ),
244
+ strings .Join (quoteColumnNames (copyAndReplace (fk .Columns , column .Name , asName )), ", " ),
101
245
pq .QuoteIdentifier (fk .ReferencedTable ),
102
246
strings .Join (quoteColumnNames (fk .ReferencedColumns ), ", " ),
103
247
fk .OnDelete ,
104
248
)
105
249
}
106
250
}
107
251
108
- _ , err := d .conn .ExecContext (ctx , sql )
109
- if err != nil {
110
- return err
252
+ return sql
253
+ }
254
+
255
+ func (d * duplicatorStmtBuilder ) duplicateDefault (column * schema.Column , asName string ) string {
256
+ if column .Default == nil {
257
+ return ""
111
258
}
112
259
260
+ const cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
261
+
113
262
// Generate SQL to duplicate any default value on the column. This may fail
114
263
// if the default value is not valid for the new column type, in which case
115
264
// the error is ignored.
116
- if d .column .Default != nil {
117
- sql := fmt .Sprintf (cSetDefaultSQL , pq .QuoteIdentifier (d .table .Name ), d .asName , * d .column .Default )
118
-
119
- _ , err := d .conn .ExecContext (ctx , sql )
265
+ return fmt .Sprintf (cSetDefaultSQL , pq .QuoteIdentifier (d .table .Name ), asName , * column .Default )
266
+ }
120
267
121
- err = errorIgnoringErrorCode (err , dataTypeMismatchErrorCode )
122
- if err != nil {
123
- return err
124
- }
268
+ func (d * duplicatorStmtBuilder ) duplicateComment (column * schema.Column , asName string ) string {
269
+ if column .Comment == "" {
270
+ return ""
125
271
}
126
272
127
- // Generate SQL to duplicate any check constraints on the column. This may faile
128
- // if the check constraint is not valid for the new column type, in which case
129
- // the error is ignored.
130
- for _ , cc := range d .table .CheckConstraints {
131
- if cc .Name == d .withoutConstraint {
132
- continue
133
- }
134
-
135
- if slices .Contains (cc .Columns , d .column .Name ) {
136
- sql := fmt .Sprintf (cAlterTableAddCheckConstraintSQL ,
137
- pq .QuoteIdentifier (d .table .Name ),
138
- pq .QuoteIdentifier (DuplicationName (cc .Name )),
139
- rewriteCheckExpression (cc .Definition , d .column .Name , d .asName ),
140
- )
141
-
142
- _ , err := d .conn .ExecContext (ctx , sql )
143
-
144
- err = errorIgnoringErrorCode (err , undefinedFunctionErrorCode )
145
- if err != nil {
146
- return err
147
- }
148
- }
149
- }
273
+ const cCommentOnColumnSQL = `COMMENT ON COLUMN %s.%s IS %s`
150
274
151
275
// Generate SQL to duplicate the column's comment
152
- if d .column .Comment != "" {
153
- sql = fmt .Sprintf (cCommentOnColumnSQL ,
154
- pq .QuoteIdentifier (d .table .Name ),
155
- pq .QuoteIdentifier (d .asName ),
156
- pq .QuoteLiteral (d .column .Comment ),
157
- )
158
-
159
- _ , err = d .conn .ExecContext (ctx , sql )
160
- if err != nil {
161
- return err
162
- }
163
- }
164
-
165
- // Generate SQL to duplicate any unique constraints on the column
166
- // The constraint is duplicated by adding a unique index on the column concurrently.
167
- // The index is converted into a unique constraint on migration completion.
168
- for _ , uc := range d .table .UniqueConstraints {
169
- if uc .Name == d .withoutConstraint {
170
- continue
171
- }
172
-
173
- if slices .Contains (uc .Columns , d .column .Name ) {
174
- sql = fmt .Sprintf (cCreateUniqueIndexSQL ,
175
- pq .QuoteIdentifier (DuplicationName (uc .Name )),
176
- pq .QuoteIdentifier (d .table .Name ),
177
- strings .Join (quoteColumnNames (copyAndReplace (uc .Columns , d .column .Name , d .asName )), ", " ),
178
- )
179
-
180
- _ , err = d .conn .ExecContext (ctx , sql )
181
- if err != nil {
182
- return err
183
- }
184
- }
185
- }
186
-
187
- return nil
276
+ return fmt .Sprintf (cCommentOnColumnSQL ,
277
+ pq .QuoteIdentifier (d .table .Name ),
278
+ pq .QuoteIdentifier (asName ),
279
+ pq .QuoteLiteral (column .Comment ),
280
+ )
188
281
}
189
282
190
283
// DiplicationName returns the name of a duplicated column.
0 commit comments