@@ -7,11 +7,9 @@ import (
7
7
"database/sql"
8
8
"errors"
9
9
"fmt"
10
- "strings"
11
10
"time"
12
11
13
- "github.com/lib/pq"
14
-
12
+ "github.com/xataio/pgroll/pkg/backfill/templates"
15
13
"github.com/xataio/pgroll/pkg/db"
16
14
"github.com/xataio/pgroll/pkg/schema"
17
15
)
@@ -59,7 +57,13 @@ func (bf *Backfill) Start(ctx context.Context, table *schema.Table) error {
59
57
}
60
58
61
59
// Create a batcher for the table.
62
- b := newBatcher (table , bf .batchSize )
60
+ b := batcher {
61
+ BatchConfig : templates.BatchConfig {
62
+ TableName : table .Name ,
63
+ PrimaryKey : identityColumns ,
64
+ BatchSize : bf .batchSize ,
65
+ },
66
+ }
63
67
64
68
// Update each batch of rows, invoking callbacks for each one.
65
69
for batch := 0 ; ; batch ++ {
@@ -158,109 +162,34 @@ func getIdentityColumns(table *schema.Table) []string {
158
162
return nil
159
163
}
160
164
165
+ // A batcher is responsible for updating a batch of rows in a table.
166
+ // It holds the state necessary to update the next batch of rows.
161
167
type batcher struct {
162
- statementBuilder * batchStatementBuilder
163
- lastValues []string
164
- }
165
-
166
- func newBatcher (table * schema.Table , batchSize int ) * batcher {
167
- return & batcher {
168
- statementBuilder : newBatchStatementBuilder (table .Name , getIdentityColumns (table ), batchSize ),
169
- lastValues : make ([]string , len (getIdentityColumns (table ))),
170
- }
168
+ templates.BatchConfig
171
169
}
172
170
173
171
func (b * batcher ) updateBatch (ctx context.Context , conn db.DB ) error {
174
172
return conn .WithRetryableTransaction (ctx , func (ctx context.Context , tx * sql.Tx ) error {
175
173
// Build the query to update the next batch of rows
176
- query := b .statementBuilder .buildQuery (b .lastValues )
174
+ sql , err := templates .BuildSQL (b .BatchConfig )
175
+ if err != nil {
176
+ return err
177
+ }
177
178
178
179
// Execute the query to update the next batch of rows and update the last PK
179
180
// value for the next batch
180
- wrapper := make ([]any , len (b .lastValues ))
181
- for i := range b .lastValues {
182
- wrapper [i ] = & b .lastValues [i ]
181
+ if b .LastValue == nil {
182
+ b .LastValue = make ([]string , len (b .PrimaryKey ))
183
+ }
184
+ wrapper := make ([]any , len (b .LastValue ))
185
+ for i := range b .LastValue {
186
+ wrapper [i ] = & b .LastValue [i ]
183
187
}
184
- err : = tx .QueryRowContext (ctx , query ).Scan (wrapper ... )
188
+ err = tx .QueryRowContext (ctx , sql ).Scan (wrapper ... )
185
189
if err != nil {
186
190
return err
187
191
}
188
192
189
193
return nil
190
194
})
191
195
}
192
-
193
- type batchStatementBuilder struct {
194
- tableName string
195
- identityColumns []string
196
- batchSize int
197
- }
198
-
199
- func newBatchStatementBuilder (tableName string , identityColumnNames []string , batchSize int ) * batchStatementBuilder {
200
- quotedCols := make ([]string , len (identityColumnNames ))
201
- for i , col := range identityColumnNames {
202
- quotedCols [i ] = pq .QuoteIdentifier (col )
203
- }
204
- return & batchStatementBuilder {
205
- tableName : pq .QuoteIdentifier (tableName ),
206
- identityColumns : quotedCols ,
207
- batchSize : batchSize ,
208
- }
209
- }
210
-
211
- // buildQuery builds the query used to update the next batch of rows.
212
- func (sb * batchStatementBuilder ) buildQuery (lastValues []string ) string {
213
- return fmt .Sprintf ("WITH batch AS (%[1]s), update AS (%[2]s) %[3]s" ,
214
- sb .buildBatchSubQuery (lastValues ),
215
- sb .buildUpdateBatchSubQuery (),
216
- sb .buildLastValueQuery ())
217
- }
218
-
219
- // fetch the next batch of PK of rows to update
220
- func (sb * batchStatementBuilder ) buildBatchSubQuery (lastValues []string ) string {
221
- whereClause := ""
222
- if len (lastValues ) != 0 && lastValues [0 ] != "" {
223
- whereClause = fmt .Sprintf ("WHERE (%s) > (%s)" ,
224
- strings .Join (sb .identityColumns , ", " ), strings .Join (quoteLiteralList (lastValues ), ", " ))
225
- }
226
-
227
- return fmt .Sprintf ("SELECT %[1]s FROM %[2]s %[3]s ORDER BY %[1]s LIMIT %[4]d FOR NO KEY UPDATE" ,
228
- strings .Join (sb .identityColumns , ", " ), sb .tableName , whereClause , sb .batchSize )
229
- }
230
-
231
- func quoteLiteralList (l []string ) []string {
232
- quoted := make ([]string , len (l ))
233
- for i , v := range l {
234
- quoted [i ] = pq .QuoteLiteral (v )
235
- }
236
- return quoted
237
- }
238
-
239
- // update the rows in the batch
240
- func (sb * batchStatementBuilder ) buildUpdateBatchSubQuery () string {
241
- conditions := make ([]string , len (sb .identityColumns ))
242
- for i , col := range sb .identityColumns {
243
- conditions [i ] = fmt .Sprintf ("%[1]s.%[2]s = batch.%[2]s" , sb .tableName , col )
244
- }
245
- updateWhereClause := "WHERE " + strings .Join (conditions , " AND " )
246
-
247
- setStmt := fmt .Sprintf ("%[1]s = %[2]s.%[1]s" , sb .identityColumns [0 ], sb .tableName )
248
- for i := 1 ; i < len (sb .identityColumns ); i ++ {
249
- setStmt += fmt .Sprintf (", %[1]s = %[2]s.%[1]s" , sb .identityColumns [i ], sb .tableName )
250
- }
251
- updateReturning := sb .tableName + "." + sb .identityColumns [0 ]
252
- for i := 1 ; i < len (sb .identityColumns ); i ++ {
253
- updateReturning += ", " + sb .tableName + "." + sb .identityColumns [i ]
254
- }
255
- return fmt .Sprintf ("UPDATE %[1]s SET %[2]s FROM batch %[3]s RETURNING %[4]s" ,
256
- sb .tableName , setStmt , updateWhereClause , updateReturning )
257
- }
258
-
259
- // fetch the last values of the PK column
260
- func (sb * batchStatementBuilder ) buildLastValueQuery () string {
261
- lastValues := make ([]string , len (sb .identityColumns ))
262
- for i , col := range sb .identityColumns {
263
- lastValues [i ] = "LAST_VALUE(" + col + ") OVER()"
264
- }
265
- return fmt .Sprintf ("SELECT %[1]s FROM update" , strings .Join (lastValues , ", " ))
266
- }
0 commit comments