Commit b02c324

Support multiple PKs during backfilling (#426)
This PR adds support for using multiple primary keys during the backfilling part of a migration. It also extracts the statement builder into a separate struct, which makes it easier to check the expected statements in tests.
1 parent b6f76c7 commit b02c324
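
To make the change concrete, here is a minimal sketch (not part of the commit) of how the extracted statement builder is exercised and what it generates for a composite primary key. It mirrors the new unit test added below; the "sellers" table and the ("id", "zip") key are illustrative, and the snippet assumes it sits inside the migrations package, since the builder is unexported.

package migrations

import "fmt"

// exampleCompositeKeyQuery is a sketch: it prints the batch-update statement the
// builder generates for an illustrative table "sellers" with composite key
// ("id", "zip"), once a previous batch has returned the last values (1, 1234).
func exampleCompositeKeyQuery() {
	builder := newBatchStatementBuilder("sellers", []string{"id", "zip"}, 10)
	fmt.Println(builder.buildQuery([]int64{1, 1234}))
	// Printed statement (a single line; wrapped here for readability):
	//   WITH batch AS (
	//     SELECT "id", "zip" FROM "sellers" WHERE "id" > 1 AND "zip" > 1234
	//     ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE
	//   ), update AS (
	//     UPDATE "sellers" SET "id" = "sellers"."id", "zip" = "sellers"."zip" FROM batch
	//     WHERE "sellers"."id" = batch."id" AND "sellers"."zip" = batch."zip"
	//     RETURNING "sellers"."id", "sellers"."zip"
	//   )
	//   SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update
}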

5 files changed: +200, -40 lines changed
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+{
+  "name": "39_add_column_with_multiple_pk_in_table",
+  "operations": [
+    {
+      "add_column": {
+        "table": "sellers",
+        "column": {
+          "name": "rating",
+          "type": "int",
+          "default": "10"
+        }
+      }
+    }
+  ]
+}

pkg/migrations/backfill.go

Lines changed: 113 additions & 36 deletions
@@ -7,6 +7,7 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
+	"strings"
 	"time"

 	"github.com/lib/pq"
@@ -23,23 +24,18 @@ import (
 // 4. Repeat steps 2 and 3 until no more rows are returned.
 func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize int, batchDelay time.Duration, cbs ...CallbackFn) error {
 	// get the backfill column
-	identityColumn := getIdentityColumn(table)
-	if identityColumn == nil {
+	identityColumns := getIdentityColumns(table)
+	if identityColumns == nil {
 		return BackfillNotPossibleError{Table: table.Name}
 	}

 	// Create a batcher for the table.
-	b := batcher{
-		table:          table,
-		identityColumn: identityColumn,
-		lastValue:      nil,
-		batchSize:      batchSize,
-	}
+	b := newBatcher(table, batchSize)

 	// Update each batch of rows, invoking callbacks for each one.
 	for batch := 0; ; batch++ {
 		for _, cb := range cbs {
-			cb(int64(batch * b.batchSize))
+			cb(int64(batch * batchSize))
 		}

 		if err := b.updateBatch(ctx, conn); err != nil {
@@ -61,25 +57,29 @@ func Backfill(ctx context.Context, conn db.DB, table *schema.Table, batchSize in

 // checkBackfill will return an error if the backfill operation is not supported.
 func checkBackfill(table *schema.Table) error {
-	col := getIdentityColumn(table)
-	if col == nil {
+	cols := getIdentityColumns(table)
+	if cols == nil {
 		return BackfillNotPossibleError{Table: table.Name}
 	}

 	return nil
 }

 // getIdentityColumn will return a column suitable for use in a backfill operation.
-func getIdentityColumn(table *schema.Table) *schema.Column {
+func getIdentityColumns(table *schema.Table) []string {
 	pks := table.GetPrimaryKey()
-	if len(pks) == 1 {
-		return pks[0]
+	if len(pks) != 0 {
+		pkNames := make([]string, len(pks))
+		for i, pk := range pks {
+			pkNames[i] = pk.Name
+		}
+		return pkNames
 	}

 	// If there is no primary key, look for a unique not null column
 	for _, col := range table.Columns {
 		if col.Unique && !col.Nullable {
-			return &col
+			return []string{col.Name}
 		}
 	}

@@ -88,20 +88,25 @@ func getIdentityColumn(table *schema.Table) *schema.Column {
 }

 type batcher struct {
-	table          *schema.Table
-	identityColumn *schema.Column
-	lastValue      *string
-	batchSize      int
+	statementBuilder *batchStatementBuilder
+	lastValues       any
+}
+
+func newBatcher(table *schema.Table, batchSize int) *batcher {
+	return &batcher{
+		statementBuilder: newBatchStatementBuilder(table.Name, getIdentityColumns(table), batchSize),
+		lastValues:       nil,
+	}
 }

 func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error {
 	return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
 		// Build the query to update the next batch of rows
-		query := b.buildQuery()
+		query := b.statementBuilder.buildQuery(b.lastValues)

 		// Execute the query to update the next batch of rows and update the last PK
 		// value for the next batch
-		err := tx.QueryRowContext(ctx, query).Scan(&b.lastValue)
+		err := tx.QueryRowContext(ctx, query).Scan(&b.lastValues)
 		if err != nil {
 			return err
 		}
@@ -110,23 +115,95 @@ func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error {
 	})
 }

+type batchStatementBuilder struct {
+	tableName       string
+	identityColumns []string
+	batchSize       int
+}
+
+func newBatchStatementBuilder(tableName string, identityColumnNames []string, batchSize int) *batchStatementBuilder {
+	quotedCols := make([]string, len(identityColumnNames))
+	for i, col := range identityColumnNames {
+		quotedCols[i] = pq.QuoteIdentifier(col)
+	}
+	return &batchStatementBuilder{
+		tableName:       pq.QuoteIdentifier(tableName),
+		identityColumns: quotedCols,
+		batchSize:       batchSize,
+	}
+}
+
 // buildQuery builds the query used to update the next batch of rows.
-func (b *batcher) buildQuery() string {
+func (sb *batchStatementBuilder) buildQuery(lastValues any) string {
+	return fmt.Sprintf("WITH batch AS (%[1]s), update AS (%[2]s) %[3]s",
+		sb.buildBatchSubQuery(lastValues),
+		sb.buildUpdateBatchSubQuery(),
+		sb.buildLastValueQuery())
+}
+
+// fetch the next batch of PK of rows to update
+func (sb *batchStatementBuilder) buildBatchSubQuery(lastValues any) string {
 	whereClause := ""
-	if b.lastValue != nil {
-		whereClause = fmt.Sprintf("WHERE %s > %v", pq.QuoteIdentifier(b.identityColumn.Name), pq.QuoteLiteral(*b.lastValue))
+	if lastValues != nil {
+		conditions := make([]string, len(sb.identityColumns))
+		switch lastVals := lastValues.(type) {
+		case []int64:
+			for i, col := range sb.identityColumns {
+				conditions[i] = fmt.Sprintf("%s > %d", col, lastVals[i])
+			}
+		case []string:
+			for i, col := range sb.identityColumns {
+				conditions[i] = fmt.Sprintf("%s > %s", col, pq.QuoteLiteral(lastVals[i]))
+			}
+		case []any:
+			for i, col := range sb.identityColumns {
+				if v, ok := lastVals[i].(int); ok {
+					conditions[i] = fmt.Sprintf("%s > %d", col, v)
+				} else if v, ok := lastVals[i].(string); ok {
+					conditions[i] = fmt.Sprintf("%s > %s", col, pq.QuoteLiteral(v))
+				} else {
+					panic("unsupported type")
+				}
+			}
+		case int64:
+			conditions[0] = fmt.Sprintf("%s > %d ", sb.identityColumns[0], lastVals)
+		case string:
+			conditions[0] = fmt.Sprintf("%s > %s ", sb.identityColumns[0], pq.QuoteLiteral(lastVals))
+		default:
+			panic("unsupported type")
+		}
+		whereClause = "WHERE " + strings.Join(conditions, " AND ")
 	}

-	return fmt.Sprintf(`
-		WITH batch AS (
-			SELECT %[1]s FROM %[2]s %[4]s ORDER BY %[1]s LIMIT %[3]d FOR NO KEY UPDATE
-		), update AS (
-			UPDATE %[2]s SET %[1]s=%[2]s.%[1]s FROM batch WHERE %[2]s.%[1]s = batch.%[1]s RETURNING %[2]s.%[1]s
-		)
-		SELECT LAST_VALUE(%[1]s) OVER() FROM update
-		`,
-		pq.QuoteIdentifier(b.identityColumn.Name),
-		pq.QuoteIdentifier(b.table.Name),
-		b.batchSize,
-		whereClause)
+	return fmt.Sprintf("SELECT %[1]s FROM %[2]s %[3]s ORDER BY %[1]s LIMIT %[4]d FOR NO KEY UPDATE",
+		strings.Join(sb.identityColumns, ", "), sb.tableName, whereClause, sb.batchSize)
+}
+
+// update the rows in the batch
+func (sb *batchStatementBuilder) buildUpdateBatchSubQuery() string {
+	conditions := make([]string, len(sb.identityColumns))
+	for i, col := range sb.identityColumns {
+		conditions[i] = fmt.Sprintf("%[1]s.%[2]s = batch.%[2]s", sb.tableName, col)
+	}
+	updateWhereClause := "WHERE " + strings.Join(conditions, " AND ")
+
+	setStmt := fmt.Sprintf("%[1]s = %[2]s.%[1]s", sb.identityColumns[0], sb.tableName)
+	for i := 1; i < len(sb.identityColumns); i++ {
+		setStmt += fmt.Sprintf(", %[1]s = %[2]s.%[1]s", sb.identityColumns[i], sb.tableName)
+	}
+	updateReturning := sb.tableName + "." + sb.identityColumns[0]
+	for i := 1; i < len(sb.identityColumns); i++ {
+		updateReturning += ", " + sb.tableName + "." + sb.identityColumns[i]
+	}
+	return fmt.Sprintf("UPDATE %[1]s SET %[2]s FROM batch %[3]s RETURNING %[4]s",
+		sb.tableName, setStmt, updateWhereClause, updateReturning)
+}
+
+// fetch the last values of the PK column
+func (sb *batchStatementBuilder) buildLastValueQuery() string {
+	lastValues := make([]string, len(sb.identityColumns))
+	for i, col := range sb.identityColumns {
+		lastValues[i] = "LAST_VALUE(" + col + ") OVER()"
+	}
+	return fmt.Sprintf("SELECT %[1]s FROM update", strings.Join(lastValues, ", "))
 }

pkg/migrations/backfill_test.go

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package migrations
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestBatchStatementBuilder(t *testing.T) {
+	tests := map[string]struct {
+		tableName       string
+		identityColumns []string
+		batchSize       int
+		lasValues       any
+		expected        string
+	}{
+		"single identity column no last value": {
+			tableName:       "table_name",
+			identityColumns: []string{"id"},
+			batchSize:       10,
+			expected:        `WITH batch AS (SELECT "id" FROM "table_name" ORDER BY "id" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id" FROM batch WHERE "table_name"."id" = batch."id" RETURNING "table_name"."id") SELECT LAST_VALUE("id") OVER() FROM update`,
+		},
+		"multiple identity columns no last value": {
+			tableName:       "table_name",
+			identityColumns: []string{"id", "zip"},
+			batchSize:       10,
+			expected:        `WITH batch AS (SELECT "id", "zip" FROM "table_name" ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id", "zip" = "table_name"."zip" FROM batch WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" RETURNING "table_name"."id", "table_name"."zip") SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update`,
+		},
+		"single identity column with last value": {
+			tableName:       "table_name",
+			identityColumns: []string{"id"},
+			batchSize:       10,
+			lasValues:       []int64{1},
+			expected:        `WITH batch AS (SELECT "id" FROM "table_name" WHERE "id" > 1 ORDER BY "id" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id" FROM batch WHERE "table_name"."id" = batch."id" RETURNING "table_name"."id") SELECT LAST_VALUE("id") OVER() FROM update`,
+		},
+		"multiple identity columns with last value": {
+			tableName:       "table_name",
+			identityColumns: []string{"id", "zip"},
+			batchSize:       10,
+			lasValues:       []int64{1, 1234},
+			expected:        `WITH batch AS (SELECT "id", "zip" FROM "table_name" WHERE "id" > 1 AND "zip" > 1234 ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id", "zip" = "table_name"."zip" FROM batch WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" RETURNING "table_name"."id", "table_name"."zip") SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update`,
+		},
+		"multiple string identity columns with last value": {
+			tableName:       "table_name",
+			identityColumns: []string{"id", "zip"},
+			batchSize:       10,
+			lasValues:       []string{"1", "1234"},
+			expected:        `WITH batch AS (SELECT "id", "zip" FROM "table_name" WHERE "id" > '1' AND "zip" > '1234' ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id", "zip" = "table_name"."zip" FROM batch WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" RETURNING "table_name"."id", "table_name"."zip") SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update`,
+		},
+		"multiple different identity columns with last value": {
+			tableName:       "table_name",
+			identityColumns: []string{"id", "zip"},
+			batchSize:       10,
+			lasValues:       []any{1, "1234"},
+			expected:        `WITH batch AS (SELECT "id", "zip" FROM "table_name" WHERE "id" > 1 AND "zip" > '1234' ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id", "zip" = "table_name"."zip" FROM batch WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" RETURNING "table_name"."id", "table_name"."zip") SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update`,
+		},
+	}
+
+	for name, test := range tests {
+		t.Run(name, func(t *testing.T) {
+			builder := newBatchStatementBuilder(test.tableName, test.identityColumns, test.batchSize)
+			actual := builder.buildQuery(test.lasValues)
+			assert.Equal(t, test.expected, actual)
+		})
+	}
+}

pkg/migrations/op_add_column_test.go

Lines changed: 1 addition & 2 deletions
@@ -1164,7 +1164,7 @@ func TestAddColumnValidation(t *testing.T) {
			wantStartErr: migrations.FieldRequiredError{Name: "up"},
		},
		{
-			name: "table must have a primary key on exactly one column if up is defined",
+			name: "table can have multiple primary keys",
			migrations: []migrations.Migration{
				{
					Name: "01_add_table",
@@ -1189,7 +1189,6 @@
					},
				},
			},
-			wantStartErr: migrations.BackfillNotPossibleError{Table: "orders"},
		},
		{
			name: "table has no restrictions on primary keys if up is not defined",

pkg/migrations/op_alter_column_test.go

Lines changed: 3 additions & 2 deletions
@@ -681,7 +681,7 @@ func TestAlterColumnValidation(t *testing.T) {
			wantStartErr: migrations.AlterColumnNoChangesError{Table: "posts", Column: "title"},
		},
		{
-			name: "if a backfill is required, the table must have a primary key on exactly one column",
+			name: "backfill with multiple primary keys",
			migrations: []migrations.Migration{
				{
					Name: "01_add_table",
@@ -699,11 +699,12 @@
						Table:    "orders",
						Column:   "quantity",
						Nullable: ptr(false),
+						Up:       "1",
					},
				},
			},
		},
-			wantStartErr: migrations.BackfillNotPossibleError{Table: "orders"},
+			wantStartErr: nil,
		},
		{
			name: "rename-only operations don't have primary key requirements",
