Skip to content

Commit a149a4f

Browse files
authored
New DBAction: CreateCheckConstraintAction (#837)
This PR adds a new `DBAction` that is responsible for adding new check constraints. It is used in adding check constraints to columns and tables. Related #742
1 parent 16173dd commit a149a4f

File tree

3 files changed

+49
-43
lines changed

3 files changed

+49
-43
lines changed

pkg/migrations/dbactions.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,3 +394,50 @@ func (a *validateConstraintAction) Execute(ctx context.Context) error {
394394
pq.QuoteIdentifier(a.constraint)))
395395
return err
396396
}
397+
398+
// CreateCheckConstraintAction creates a check constraint on a table.
399+
type CreateCheckConstraintAction struct {
400+
conn db.DB
401+
table string
402+
columns []string
403+
constraint string
404+
check string
405+
noInherit bool
406+
skipValidation bool
407+
}
408+
409+
func NewCreateCheckConstraintAction(conn db.DB, table, constraint, check string, columns []string, noInherit, skipValidation bool) *CreateCheckConstraintAction {
410+
return &CreateCheckConstraintAction{
411+
conn: conn,
412+
table: table,
413+
columns: columns,
414+
check: check,
415+
constraint: constraint,
416+
noInherit: noInherit,
417+
skipValidation: skipValidation,
418+
}
419+
}
420+
421+
func (a *CreateCheckConstraintAction) Execute(ctx context.Context) error {
422+
sql := fmt.Sprintf("ALTER TABLE %s ADD ", pq.QuoteIdentifier(a.table))
423+
424+
writer := &ConstraintSQLWriter{
425+
Name: a.constraint,
426+
SkipValidation: a.skipValidation,
427+
}
428+
sql += writer.WriteCheck(rewriteCheckExpression(a.check, a.columns...), a.noInherit)
429+
_, err := a.conn.ExecContext(ctx, sql)
430+
return err
431+
}
432+
433+
// In order for the `check` expression to be easy to write, migration authors specify
434+
// the check expression as though it were being applied to the old column,
435+
// On migration start, however, the check is actually applied to the new (temporary)
436+
// column.
437+
// This function naively rewrites the check expression to apply to the new column.
438+
func rewriteCheckExpression(check string, columns ...string) string {
439+
for _, col := range columns {
440+
check = strings.ReplaceAll(check, col, TemporaryName(col))
441+
}
442+
return check
443+
}

pkg/migrations/op_create_constraint.go

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func (o *OpCreateConstraint) Start(ctx context.Context, l Logger, conn db.DB, la
9393
case OpCreateConstraintTypeUnique, OpCreateConstraintTypePrimaryKey:
9494
return table, NewCreateUniqueIndexConcurrentlyAction(conn, s.Name, o.Name, table.Name, temporaryNames(o.Columns)...).Execute(ctx)
9595
case OpCreateConstraintTypeCheck:
96-
return table, o.addCheckConstraint(ctx, conn, table.Name)
96+
return table, NewCreateCheckConstraintAction(conn, table.Name, o.Name, *o.Check, o.Columns, o.NoInherit, true).Execute(ctx)
9797
case OpCreateConstraintTypeForeignKey:
9898
return table, o.addForeignKeyConstraint(ctx, conn, table)
9999
}
@@ -284,18 +284,6 @@ func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) err
284284
return nil
285285
}
286286

287-
func (o *OpCreateConstraint) addCheckConstraint(ctx context.Context, conn db.DB, tableName string) error {
288-
sql := fmt.Sprintf("ALTER TABLE %s ADD ", pq.QuoteIdentifier(tableName))
289-
290-
writer := &ConstraintSQLWriter{
291-
Name: o.Name,
292-
SkipValidation: true,
293-
}
294-
sql += writer.WriteCheck(rewriteCheckExpression(*o.Check, o.Columns...), o.NoInherit)
295-
_, err := conn.ExecContext(ctx, sql)
296-
return err
297-
}
298-
299287
func (o *OpCreateConstraint) addForeignKeyConstraint(ctx context.Context, conn db.DB, table *schema.Table) error {
300288
sql := fmt.Sprintf("ALTER TABLE %s ADD ", pq.QuoteIdentifier(table.Name))
301289

pkg/migrations/op_set_check.go

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@ package migrations
55
import (
66
"context"
77
"fmt"
8-
"strings"
9-
10-
"github.com/lib/pq"
118

129
"github.com/xataio/pgroll/pkg/db"
1310
"github.com/xataio/pgroll/pkg/schema"
@@ -32,7 +29,7 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, l Logger, conn db.DB,
3229
}
3330

3431
// Add the check constraint to the new column as NOT VALID.
35-
if err := o.addCheckConstraint(ctx, conn, s); err != nil {
32+
if err := NewCreateCheckConstraintAction(conn, table.Name, o.Check.Name, o.Check.Constraint, []string{o.Column}, o.Check.NoInherit, true).Execute(ctx); err != nil {
3633
return nil, fmt.Errorf("failed to add check constraint: %w", err)
3734
}
3835

@@ -88,29 +85,3 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e
8885

8986
return nil
9087
}
91-
92-
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.DB, s *schema.Schema) error {
93-
table := s.GetTable(o.Table)
94-
sql := fmt.Sprintf("ALTER TABLE %s ADD ", pq.QuoteIdentifier(table.Name))
95-
96-
writer := &ConstraintSQLWriter{
97-
Name: o.Check.Name,
98-
SkipValidation: true,
99-
}
100-
sql += writer.WriteCheck(rewriteCheckExpression(o.Check.Constraint, o.Column), o.Check.NoInherit)
101-
_, err := conn.ExecContext(ctx, sql)
102-
103-
return err
104-
}
105-
106-
// In order for the `check` expression to be easy to write, migration authors specify
107-
// the check expression as though it were being applied to the old column,
108-
// On migration start, however, the check is actually applied to the new (temporary)
109-
// column.
110-
// This function naively rewrites the check expression to apply to the new column.
111-
func rewriteCheckExpression(check string, columns ...string) string {
112-
for _, col := range columns {
113-
check = strings.ReplaceAll(check, col, TemporaryName(col))
114-
}
115-
return check
116-
}

0 commit comments

Comments
 (0)