@@ -10,6 +10,7 @@ import (
10
10
"reflect"
11
11
"strconv"
12
12
"strings"
13
+ "sync"
13
14
"sync/atomic"
14
15
"testing"
15
16
"time"
@@ -35,14 +36,18 @@ type testServer struct {
35
36
type mockHandler struct {
36
37
// the number of times a query executed
37
38
queryCount atomic.Int32
39
+ modifier * sync.WaitGroup
38
40
}
39
41
40
42
func TestDriverOptions_SetRetriesOn (t * testing.T ) {
41
43
log .SetLevel (log .LevelDebug )
42
44
srv := CreateMockServer (t )
43
45
defer srv .Stop ()
46
+ var wg sync.WaitGroup
47
+ srv .handler .modifier = & wg
48
+ wg .Add (3 )
44
49
45
- conn ,
err := sql .
Open (
"mysql" ,
"[email protected] :3307/test?readTimeout=1s " )
50
+ conn ,
err := sql .
Open (
"mysql" ,
"[email protected] :3307/test?readTimeout=100ms " )
46
51
defer func () {
47
52
_ = conn .Close ()
48
53
}()
@@ -54,6 +59,7 @@ func TestDriverOptions_SetRetriesOn(t *testing.T) {
54
59
// we want to get a golang database/sql/driver ErrBadConn
55
60
require .ErrorIs (t , err , sqlDriver .ErrBadConn )
56
61
62
+ wg .Wait ()
57
63
// here we issue assert that even though we only issued 1 query, that the retries
58
64
// remained on and there were 3 calls to the DB.
59
65
require .EqualValues (t , 3 , srv .handler .queryCount .Load ())
@@ -63,8 +69,11 @@ func TestDriverOptions_SetRetriesOff(t *testing.T) {
63
69
log .SetLevel (log .LevelDebug )
64
70
srv := CreateMockServer (t )
65
71
defer srv .Stop ()
72
+ var wg sync.WaitGroup
73
+ srv .handler .modifier = & wg
74
+ wg .Add (1 )
66
75
67
- conn ,
err := sql .
Open (
"mysql" ,
"[email protected] :3307/test?readTimeout=1s &retries=off" )
76
+ conn ,
err := sql .
Open (
"mysql" ,
"[email protected] :3307/test?readTimeout=100ms &retries=off" )
68
77
defer func () {
69
78
_ = conn .Close ()
70
79
}()
@@ -75,6 +84,7 @@ func TestDriverOptions_SetRetriesOff(t *testing.T) {
75
84
// we want the native error from this driver implementation
76
85
require .ErrorIs (t , err , mysql .ErrBadConn )
77
86
87
+ wg .Wait ()
78
88
// here we issue assert that even though we only issued 1 query, that the retries
79
89
// remained on and there were 3 calls to the DB.
80
90
require .EqualValues (t , 1 , srv .handler .queryCount .Load ())
@@ -311,6 +321,12 @@ func (h *mockHandler) UseDB(dbName string) error {
311
321
}
312
322
313
323
func (h * mockHandler ) handleQuery (query string , binary bool , args []interface {}) (* mysql.Result , error ) {
324
+ defer func () {
325
+ if h .modifier != nil {
326
+ h .modifier .Done ()
327
+ }
328
+ }()
329
+
314
330
h .queryCount .Add (1 )
315
331
ss := strings .Split (query , " " )
316
332
switch strings .ToLower (ss [0 ]) {
@@ -329,7 +345,7 @@ func (h *mockHandler) handleQuery(query string, binary bool, args []interface{})
329
345
}, binary )
330
346
} else {
331
347
if strings .Contains (query , "slow" ) {
332
- time .Sleep (time .Second * 5 )
348
+ time .Sleep (time .Second )
333
349
}
334
350
335
351
var aValue uint64 = 1
0 commit comments