Skip to content

Commit cf50739

Browse files
Use text/template for backfill query generation (#632)
Generate the backfill query with a `text/template` instead of the `batchStatementBuilder`. Using a `text/template` instead of the builder with its `Sprintf` statements makes the structure of the query easier to see and modify as part of #583. Part of #583.
1 parent c4cd645 commit cf50739

File tree

5 files changed

+276
-147
lines changed

5 files changed

+276
-147
lines changed

pkg/backfill/backfill.go

Lines changed: 22 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@ import (
77
"database/sql"
88
"errors"
99
"fmt"
10-
"strings"
1110
"time"
1211

13-
"github.com/lib/pq"
14-
12+
"github.com/xataio/pgroll/pkg/backfill/templates"
1513
"github.com/xataio/pgroll/pkg/db"
1614
"github.com/xataio/pgroll/pkg/schema"
1715
)
@@ -59,7 +57,13 @@ func (bf *Backfill) Start(ctx context.Context, table *schema.Table) error {
5957
}
6058

6159
// Create a batcher for the table.
62-
b := newBatcher(table, bf.batchSize)
60+
b := batcher{
61+
BatchConfig: templates.BatchConfig{
62+
TableName: table.Name,
63+
PrimaryKey: identityColumns,
64+
BatchSize: bf.batchSize,
65+
},
66+
}
6367

6468
// Update each batch of rows, invoking callbacks for each one.
6569
for batch := 0; ; batch++ {
@@ -158,109 +162,34 @@ func getIdentityColumns(table *schema.Table) []string {
158162
return nil
159163
}
160164

165+
// A batcher is responsible for updating a batch of rows in a table.
166+
// It holds the state necessary to update the next batch of rows.
161167
type batcher struct {
162-
statementBuilder *batchStatementBuilder
163-
lastValues []string
164-
}
165-
166-
func newBatcher(table *schema.Table, batchSize int) *batcher {
167-
return &batcher{
168-
statementBuilder: newBatchStatementBuilder(table.Name, getIdentityColumns(table), batchSize),
169-
lastValues: make([]string, len(getIdentityColumns(table))),
170-
}
168+
templates.BatchConfig
171169
}
172170

173171
func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error {
174172
return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
175173
// Build the query to update the next batch of rows
176-
query := b.statementBuilder.buildQuery(b.lastValues)
174+
sql, err := templates.BuildSQL(b.BatchConfig)
175+
if err != nil {
176+
return err
177+
}
177178

178179
// Execute the query to update the next batch of rows and update the last PK
179180
// value for the next batch
180-
wrapper := make([]any, len(b.lastValues))
181-
for i := range b.lastValues {
182-
wrapper[i] = &b.lastValues[i]
181+
if b.LastValue == nil {
182+
b.LastValue = make([]string, len(b.PrimaryKey))
183+
}
184+
wrapper := make([]any, len(b.LastValue))
185+
for i := range b.LastValue {
186+
wrapper[i] = &b.LastValue[i]
183187
}
184-
err := tx.QueryRowContext(ctx, query).Scan(wrapper...)
188+
err = tx.QueryRowContext(ctx, sql).Scan(wrapper...)
185189
if err != nil {
186190
return err
187191
}
188192

189193
return nil
190194
})
191195
}
192-
193-
type batchStatementBuilder struct {
194-
tableName string
195-
identityColumns []string
196-
batchSize int
197-
}
198-
199-
func newBatchStatementBuilder(tableName string, identityColumnNames []string, batchSize int) *batchStatementBuilder {
200-
quotedCols := make([]string, len(identityColumnNames))
201-
for i, col := range identityColumnNames {
202-
quotedCols[i] = pq.QuoteIdentifier(col)
203-
}
204-
return &batchStatementBuilder{
205-
tableName: pq.QuoteIdentifier(tableName),
206-
identityColumns: quotedCols,
207-
batchSize: batchSize,
208-
}
209-
}
210-
211-
// buildQuery builds the query used to update the next batch of rows.
212-
func (sb *batchStatementBuilder) buildQuery(lastValues []string) string {
213-
return fmt.Sprintf("WITH batch AS (%[1]s), update AS (%[2]s) %[3]s",
214-
sb.buildBatchSubQuery(lastValues),
215-
sb.buildUpdateBatchSubQuery(),
216-
sb.buildLastValueQuery())
217-
}
218-
219-
// fetch the next batch of PK of rows to update
220-
func (sb *batchStatementBuilder) buildBatchSubQuery(lastValues []string) string {
221-
whereClause := ""
222-
if len(lastValues) != 0 && lastValues[0] != "" {
223-
whereClause = fmt.Sprintf("WHERE (%s) > (%s)",
224-
strings.Join(sb.identityColumns, ", "), strings.Join(quoteLiteralList(lastValues), ", "))
225-
}
226-
227-
return fmt.Sprintf("SELECT %[1]s FROM %[2]s %[3]s ORDER BY %[1]s LIMIT %[4]d FOR NO KEY UPDATE",
228-
strings.Join(sb.identityColumns, ", "), sb.tableName, whereClause, sb.batchSize)
229-
}
230-
231-
func quoteLiteralList(l []string) []string {
232-
quoted := make([]string, len(l))
233-
for i, v := range l {
234-
quoted[i] = pq.QuoteLiteral(v)
235-
}
236-
return quoted
237-
}
238-
239-
// update the rows in the batch
240-
func (sb *batchStatementBuilder) buildUpdateBatchSubQuery() string {
241-
conditions := make([]string, len(sb.identityColumns))
242-
for i, col := range sb.identityColumns {
243-
conditions[i] = fmt.Sprintf("%[1]s.%[2]s = batch.%[2]s", sb.tableName, col)
244-
}
245-
updateWhereClause := "WHERE " + strings.Join(conditions, " AND ")
246-
247-
setStmt := fmt.Sprintf("%[1]s = %[2]s.%[1]s", sb.identityColumns[0], sb.tableName)
248-
for i := 1; i < len(sb.identityColumns); i++ {
249-
setStmt += fmt.Sprintf(", %[1]s = %[2]s.%[1]s", sb.identityColumns[i], sb.tableName)
250-
}
251-
updateReturning := sb.tableName + "." + sb.identityColumns[0]
252-
for i := 1; i < len(sb.identityColumns); i++ {
253-
updateReturning += ", " + sb.tableName + "." + sb.identityColumns[i]
254-
}
255-
return fmt.Sprintf("UPDATE %[1]s SET %[2]s FROM batch %[3]s RETURNING %[4]s",
256-
sb.tableName, setStmt, updateWhereClause, updateReturning)
257-
}
258-
259-
// fetch the last values of the PK column
260-
func (sb *batchStatementBuilder) buildLastValueQuery() string {
261-
lastValues := make([]string, len(sb.identityColumns))
262-
for i, col := range sb.identityColumns {
263-
lastValues[i] = "LAST_VALUE(" + col + ") OVER()"
264-
}
265-
return fmt.Sprintf("SELECT %[1]s FROM update", strings.Join(lastValues, ", "))
266-
}

pkg/backfill/backfill_test.go

Lines changed: 0 additions & 54 deletions
This file was deleted.

pkg/backfill/templates/build.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package templates
4+
5+
import (
6+
"bytes"
7+
"strings"
8+
"text/template"
9+
10+
"github.com/lib/pq"
11+
)
12+
13+
// BatchConfig is the data value handed to the backfill SQL template when it
// is executed.
type BatchConfig struct {
	// TableName is the unquoted name of the table being backfilled.
	TableName string

	// PrimaryKey lists the unquoted identity column names of the table.
	PrimaryKey []string

	// LastValue holds one value per PrimaryKey column. The backfill batcher
	// leaves it nil until the first batch and then scans each batch's query
	// result into it.
	LastValue []string

	// BatchSize is the batch size the backfill was configured with.
	BatchSize int
}
19+
20+
func BuildSQL(cfg BatchConfig) (string, error) {
21+
return executeTemplate("sql", SQL, cfg)
22+
}
23+
24+
func executeTemplate(name, content string, cfg BatchConfig) (string, error) {
25+
ql := pq.QuoteLiteral
26+
qi := pq.QuoteIdentifier
27+
28+
tmpl := template.Must(template.New(name).
29+
Funcs(template.FuncMap{
30+
"ql": ql,
31+
"qi": qi,
32+
"commaSeparate": func(slice []string) string {
33+
return strings.Join(slice, ", ")
34+
},
35+
"quoteIdentifiers": func(slice []string) []string {
36+
quoted := make([]string, len(slice))
37+
for i, s := range slice {
38+
quoted[i] = qi(s)
39+
}
40+
return quoted
41+
},
42+
"quoteLiterals": func(slice []string) []string {
43+
quoted := make([]string, len(slice))
44+
for i, s := range slice {
45+
quoted[i] = ql(s)
46+
}
47+
return quoted
48+
},
49+
"updateSetClause": func(tableName string, columns []string) string {
50+
quoted := make([]string, len(columns))
51+
for i, c := range columns {
52+
quoted[i] = qi(c) + " = " + qi(tableName) + "." + qi(c)
53+
}
54+
return strings.Join(quoted, ", ")
55+
},
56+
"updateWhereClause": func(tableName string, columns []string) string {
57+
quoted := make([]string, len(columns))
58+
for i, c := range columns {
59+
quoted[i] = qi(tableName) + "." + qi(c) + " = batch." + qi(c)
60+
}
61+
return strings.Join(quoted, " AND ")
62+
},
63+
"updateReturnClause": func(tableName string, columns []string) string {
64+
quoted := make([]string, len(columns))
65+
for i, c := range columns {
66+
quoted[i] = qi(tableName) + "." + qi(c)
67+
}
68+
return strings.Join(quoted, ", ")
69+
},
70+
"selectLastValue": func(columns []string) string {
71+
quoted := make([]string, len(columns))
72+
for i, c := range columns {
73+
quoted[i] = "LAST_VALUE(" + qi(c) + ") OVER()"
74+
}
75+
return strings.Join(quoted, ", ")
76+
},
77+
}).
78+
Parse(content))
79+
80+
buf := bytes.Buffer{}
81+
if err := tmpl.Execute(&buf, cfg); err != nil {
82+
return "", err
83+
}
84+
85+
return buf.String(), nil
86+
}

0 commit comments

Comments
 (0)