Skip to content

Commit 9a24c2d

Browse files
backport of commit 9e48cc7 (#5907)
Co-authored-by: Michael Milton <[email protected]>
1 parent 9e1d3d4 commit 9a24c2d

File tree

3 files changed

+63
-16
lines changed

3 files changed

+63
-16
lines changed

internal/db/schema/internal/postgres/postgres.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,24 @@ func (p *Postgres) CommitRun(ctx context.Context) error {
202202
return nil
203203
}
204204

205+
// RollbackRun rolls back a transaction.
206+
func (p *Postgres) RollbackRun(ctx context.Context) error {
207+
const op = "postgres.(Postgres).RollbackRun"
208+
defer func() {
209+
p.tx = nil
210+
}()
211+
if p.tx == nil {
212+
return errors.New(ctx, errors.MigrationIntegrity, op, "no pending transaction")
213+
}
214+
if err := p.tx.Rollback(); err != nil {
215+
if errors.Is(err, sql.ErrTxDone) {
216+
return nil
217+
}
218+
return errors.Wrap(ctx, err, op)
219+
}
220+
return nil
221+
}
222+
205223
// Run will apply a migration. The io.Reader should provide the SQL
206224
// statements to execute, and the int is the version for that set of
207225
// statements. This should always be wrapped by StartRun and CommitRun.

internal/db/schema/manager.go

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"bytes"
88
"context"
99
"database/sql"
10+
stderrors "errors"
1011
"fmt"
1112
"io"
1213
"sync"
@@ -32,6 +33,8 @@ type driver interface {
3233
StartRun(context.Context) error
3334
// CommitRun commits a transaction, if there is an error it should rollback the transaction.
3435
CommitRun(context.Context) error
36+
// RollbackRun rolls back a transaction.
37+
RollbackRun(context.Context) error
3538
// CheckHook is a hook that runs prior to a migration's statements.
3639
// It should run in the same transaction a corresponding Run call.
3740
CheckHook(context.Context, migration.CheckFunc) (migration.Problems, error)
@@ -244,57 +247,68 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
244247
const op = "schema.(Manager).runMigrations"
245248

246249
var logEntries []RepairLog
247-
var err error
250+
var errFinal error
248251

249252
if startErr := b.driver.StartRun(ctx); startErr != nil {
250-
err = errors.Wrap(ctx, startErr, op)
251-
return nil, err
253+
errFinal = errors.Wrap(ctx, startErr, op)
254+
return nil, errFinal
252255
}
253256

254257
defer func() {
258+
if errFinal != nil {
259+
errRollback := b.driver.RollbackRun(ctx)
260+
if errRollback != nil {
261+
errFinal = stderrors.Join(errFinal, errRollback)
262+
}
263+
errFinal = errors.Wrap(ctx, errFinal, op)
264+
return
265+
}
255266
if commitErr := b.driver.CommitRun(ctx); commitErr != nil {
256-
err = errors.Wrap(ctx, commitErr, op)
267+
errFinal = errors.Wrap(ctx, commitErr, op)
257268
}
258269
}()
259270

260271
if ensureErr := b.driver.EnsureVersionTable(ctx); ensureErr != nil {
261-
err = errors.Wrap(ctx, ensureErr, op)
262-
return nil, err
272+
errFinal = errors.Wrap(ctx, ensureErr, op)
273+
return nil, errFinal
263274
}
264275

265276
if ensureErr := b.driver.EnsureMigrationLogTable(ctx); ensureErr != nil {
266-
err = errors.Wrap(ctx, ensureErr, op)
267-
return nil, err
277+
errFinal = errors.Wrap(ctx, ensureErr, op)
278+
return nil, errFinal
268279
}
269280

270281
for p.Next() {
271282
select {
272283
case <-ctx.Done():
273-
err = errors.Wrap(ctx, ctx.Err(), op)
274-
return nil, err
284+
errFinal = errors.Wrap(ctx, ctx.Err(), op)
285+
return nil, errFinal
275286
default:
276287
// context is not done yet. Continue on to the next query to execute.
277288
}
278289

279290
if h := p.PreHook(); h != nil {
280291
problems, err := b.driver.CheckHook(ctx, h.CheckFunc)
281292
if err != nil {
282-
return nil, errors.Wrap(ctx, err, op)
293+
errFinal = errors.Wrap(ctx, err, op)
294+
return nil, errFinal
283295
}
284296

285297
if len(problems) > 0 {
286298
if !b.selectedRepairs.IsSet(p.Edition(), p.Version()) {
287-
return nil, MigrationCheckError{
299+
errFinal = MigrationCheckError{
288300
Version: p.Version(),
289301
Edition: p.Edition(),
290302
Problems: problems,
291303
RepairDescription: h.RepairDescription,
292304
}
305+
return nil, errFinal
293306
}
294307

295308
repairs, err := b.driver.RepairHook(ctx, h.RepairFunc)
296309
if err != nil {
297-
return nil, errors.Wrap(ctx, err, op)
310+
errFinal = errors.Wrap(ctx, err, op)
311+
return nil, errFinal
298312
}
299313
logEntries = append(logEntries, RepairLog{
300314
Version: p.Version(),
@@ -304,8 +318,8 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
304318
}
305319
}
306320
if runErr := b.driver.Run(ctx, bytes.NewReader(p.Statements()), p.Version(), p.Edition()); runErr != nil {
307-
err = errors.Wrap(ctx, runErr, op)
308-
return nil, err
321+
errFinal = errors.Wrap(ctx, runErr, op)
322+
return nil, errFinal
309323
}
310324
}
311325

internal/db/schema/migrations/oss/postgres_97_01_test.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/hashicorp/boundary/internal/db"
1111
"github.com/hashicorp/boundary/internal/db/common"
1212
"github.com/hashicorp/boundary/internal/db/schema"
13+
"github.com/hashicorp/boundary/internal/db/schema/internal/postgres"
1314
"github.com/hashicorp/boundary/internal/db/schema/migration"
1415
"github.com/hashicorp/boundary/internal/db/schema/migrations/oss/internal/hook97001"
1516
"github.com/hashicorp/boundary/testing/dbtest"
@@ -49,7 +50,8 @@ import (
4950
// - this cannot be created
5051
func TestMigrationHook97001(t *testing.T) {
5152
const (
52-
priorMigration = 96001
53+
priorMigration = 95001
54+
latestMigration = 97005
5355
)
5456
dialect := dbtest.Postgres
5557
ctx := context.Background()
@@ -159,6 +161,19 @@ func TestMigrationHook97001(t *testing.T) {
159161
_, err = rw.Exec(ctx, query, nil)
160162
require.NoError(t, err)
161163

164+
// migrate to latest - make sure it fails
165+
// migration to the prior migration (before the one we want to test)
166+
latestm, err := schema.NewManager(ctx, schema.Dialect(dialect), d)
167+
require.NoError(t, err)
168+
_, err = latestm.ApplyMigrations(ctx)
169+
require.Error(t, err)
170+
171+
driver, err := postgres.New(ctx, d)
172+
require.NoError(t, err)
173+
schemaVer, _, err := driver.CurrentState(ctx, "oss")
174+
require.NoError(t, err)
175+
require.Equal(t, priorMigration, schemaVer)
176+
162177
tx, err := d.BeginTx(ctx, nil)
163178
require.NoError(t, err)
164179

0 commit comments

Comments
 (0)