Skip to content

Commit 813b0a1

Browse files
authored
Respect context cancellation when backing off (#437)
1 parent ddd44bb commit 813b0a1

File tree

2 files changed

+73
-8
lines changed

2 files changed

+73
-8
lines changed

pkg/db/db.go

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,13 @@ func (db *RDB) ExecContext(ctx context.Context, query string, args ...interface{
4242

4343
pqErr := &pq.Error{}
4444
if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode {
45-
<-time.After(b.Duration())
46-
} else {
47-
return nil, err
45+
if err := sleepCtx(ctx, b.Duration()); err != nil {
46+
return nil, err
47+
}
48+
continue
4849
}
50+
51+
return nil, err
4952
}
5053
}
5154

@@ -70,13 +73,25 @@ func (db *RDB) WithRetryableTransaction(ctx context.Context, f func(context.Cont
7073

7174
pqErr := &pq.Error{}
7275
if errors.As(err, &pqErr) && pqErr.Code == lockNotAvailableErrorCode {
73-
<-time.After(b.Duration())
74-
} else {
75-
return err
76+
if err := sleepCtx(ctx, b.Duration()); err != nil {
77+
return err
78+
}
79+
continue
7680
}
81+
82+
return err
7783
}
7884
}
7985

8086
func (db *RDB) Close() error {
8187
return db.DB.Close()
8288
}
89+
90+
func sleepCtx(ctx context.Context, d time.Duration) error {
91+
select {
92+
case <-ctx.Done():
93+
return ctx.Err()
94+
case <-time.After(d):
95+
return nil
96+
}
97+
}

pkg/db/db_test.go

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ import (
99
"testing"
1010
"time"
1111

12-
"github.com/xataio/pgroll/internal/testutils"
13-
1412
"github.com/stretchr/testify/require"
13+
14+
"github.com/xataio/pgroll/internal/testutils"
1515
"github.com/xataio/pgroll/pkg/db"
1616
)
1717

@@ -37,6 +37,30 @@ func TestExecContext(t *testing.T) {
3737
})
3838
}
3939

40+
func TestExecContextWhenContextCancelled(t *testing.T) {
41+
t.Parallel()
42+
43+
testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
44+
ctx := context.Background()
45+
ctx, cancel := context.WithCancel(ctx)
46+
47+
// create a table on which an exclusive lock is held for 2 seconds
48+
setupTableLock(t, connStr, 2*time.Second)
49+
50+
// set the lock timeout to 100ms
51+
ensureLockTimeout(t, conn, 100)
52+
53+
// execute a query that should retry until the lock is released
54+
rdb := &db.RDB{DB: conn}
55+
56+
// Cancel the context before the lock times out
57+
go time.AfterFunc(500*time.Millisecond, cancel)
58+
59+
_, err := rdb.ExecContext(ctx, "INSERT INTO test(id) VALUES (1)")
60+
require.Errorf(t, err, "context canceled")
61+
})
62+
}
63+
4064
func TestWithRetryableTransaction(t *testing.T) {
4165
t.Parallel()
4266

@@ -58,6 +82,32 @@ func TestWithRetryableTransaction(t *testing.T) {
5882
})
5983
}
6084

85+
func TestWithRetryableTransactionWhenContextCancelled(t *testing.T) {
86+
t.Parallel()
87+
88+
testutils.WithConnectionToContainer(t, func(conn *sql.DB, connStr string) {
89+
ctx := context.Background()
90+
ctx, cancel := context.WithCancel(ctx)
91+
92+
// create a table on which an exclusive lock is held for 2 seconds
93+
setupTableLock(t, connStr, 2*time.Second)
94+
95+
// set the lock timeout to 100ms
96+
ensureLockTimeout(t, conn, 100)
97+
98+
// run a transaction that should retry until the lock is released
99+
rdb := &db.RDB{DB: conn}
100+
101+
// Cancel the context before the lock times out
102+
go time.AfterFunc(500*time.Millisecond, cancel)
103+
104+
err := rdb.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
105+
return tx.QueryRowContext(ctx, "SELECT 1 FROM test").Err()
106+
})
107+
require.Errorf(t, err, "context canceled")
108+
})
109+
}
110+
61111
// setupTableLock:
62112
// * connects to the database
63113
// * creates a table in the database

0 commit comments

Comments
 (0)