7
7
"bytes"
8
8
"context"
9
9
"database/sql"
10
+ stderrors "errors"
10
11
"fmt"
11
12
"io"
12
13
"sync"
@@ -32,6 +33,8 @@ type driver interface {
32
33
StartRun (context.Context ) error
33
34
// CommitRun commits a transaction, if there is an error it should rollback the transaction.
34
35
CommitRun (context.Context ) error
36
+ // RollbackRun rolls back a transaction.
37
+ RollbackRun (context.Context ) error
35
38
// CheckHook is a hook that runs prior to a migration's statements.
36
39
// It should run in the same transaction a corresponding Run call.
37
40
CheckHook (context.Context , migration.CheckFunc ) (migration.Problems , error )
@@ -244,57 +247,68 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
244
247
const op = "schema.(Manager).runMigrations"
245
248
246
249
var logEntries []RepairLog
247
- var err error
250
+ var errFinal error
248
251
249
252
if startErr := b .driver .StartRun (ctx ); startErr != nil {
250
- err = errors .Wrap (ctx , startErr , op )
251
- return nil , err
253
+ errFinal = errors .Wrap (ctx , startErr , op )
254
+ return nil , errFinal
252
255
}
253
256
254
257
defer func () {
258
+ if errFinal != nil {
259
+ errRollback := b .driver .RollbackRun (ctx )
260
+ if errRollback != nil {
261
+ errFinal = stderrors .Join (errFinal , errRollback )
262
+ }
263
+ errFinal = errors .Wrap (ctx , errFinal , op )
264
+ return
265
+ }
255
266
if commitErr := b .driver .CommitRun (ctx ); commitErr != nil {
256
- err = errors .Wrap (ctx , commitErr , op )
267
+ errFinal = errors .Wrap (ctx , commitErr , op )
257
268
}
258
269
}()
259
270
260
271
if ensureErr := b .driver .EnsureVersionTable (ctx ); ensureErr != nil {
261
- err = errors .Wrap (ctx , ensureErr , op )
262
- return nil , err
272
+ errFinal = errors .Wrap (ctx , ensureErr , op )
273
+ return nil , errFinal
263
274
}
264
275
265
276
if ensureErr := b .driver .EnsureMigrationLogTable (ctx ); ensureErr != nil {
266
- err = errors .Wrap (ctx , ensureErr , op )
267
- return nil , err
277
+ errFinal = errors .Wrap (ctx , ensureErr , op )
278
+ return nil , errFinal
268
279
}
269
280
270
281
for p .Next () {
271
282
select {
272
283
case <- ctx .Done ():
273
- err = errors .Wrap (ctx , ctx .Err (), op )
274
- return nil , err
284
+ errFinal = errors .Wrap (ctx , ctx .Err (), op )
285
+ return nil , errFinal
275
286
default :
276
287
// context is not done yet. Continue on to the next query to execute.
277
288
}
278
289
279
290
if h := p .PreHook (); h != nil {
280
291
problems , err := b .driver .CheckHook (ctx , h .CheckFunc )
281
292
if err != nil {
282
- return nil , errors .Wrap (ctx , err , op )
293
+ errFinal = errors .Wrap (ctx , err , op )
294
+ return nil , errFinal
283
295
}
284
296
285
297
if len (problems ) > 0 {
286
298
if ! b .selectedRepairs .IsSet (p .Edition (), p .Version ()) {
287
- return nil , MigrationCheckError {
299
+ errFinal = MigrationCheckError {
288
300
Version : p .Version (),
289
301
Edition : p .Edition (),
290
302
Problems : problems ,
291
303
RepairDescription : h .RepairDescription ,
292
304
}
305
+ return nil , errFinal
293
306
}
294
307
295
308
repairs , err := b .driver .RepairHook (ctx , h .RepairFunc )
296
309
if err != nil {
297
- return nil , errors .Wrap (ctx , err , op )
310
+ errFinal = errors .Wrap (ctx , err , op )
311
+ return nil , errFinal
298
312
}
299
313
logEntries = append (logEntries , RepairLog {
300
314
Version : p .Version (),
@@ -304,8 +318,8 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
304
318
}
305
319
}
306
320
if runErr := b .driver .Run (ctx , bytes .NewReader (p .Statements ()), p .Version (), p .Edition ()); runErr != nil {
307
- err = errors .Wrap (ctx , runErr , op )
308
- return nil , err
321
+ errFinal = errors .Wrap (ctx , runErr , op )
322
+ return nil , errFinal
309
323
}
310
324
}
311
325
0 commit comments