Skip to content

Commit 8f6826c

Browse files
authored
Merge pull request #8 from rdallman/fix-lockyloo-golang
Fix lockyloo golang
2 parents aa21a21 + a9176e9 commit 8f6826c

File tree

2 files changed

+64
-15
lines changed

2 files changed

+64
-15
lines changed

database/mysql/mysql.go

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
// +build go1.9
2+
13
package mysql
24

35
import (
6+
"context"
47
"crypto/tls"
58
"crypto/x509"
69
"database/sql"
@@ -35,7 +38,9 @@ type Config struct {
3538
}
3639

3740
type Mysql struct {
38-
db *sql.DB
41+
// mysql RELEASE_LOCK must be called from the same conn, so
42+
// just do everything over a single conn anyway.
43+
conn *sql.Conn
3944
isLocked bool
4045

4146
config *Config
@@ -67,8 +72,13 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
6772
config.MigrationsTable = DefaultMigrationsTable
6873
}
6974

75+
conn, err := instance.Conn(context.Background())
76+
if err != nil {
77+
return nil, err
78+
}
79+
7080
mx := &Mysql{
71-
db: instance,
81+
conn: conn,
7282
config: config,
7383
}
7484

@@ -148,7 +158,7 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
148158
}
149159

150160
func (m *Mysql) Close() error {
151-
return m.db.Close()
161+
return m.conn.Close()
152162
}
153163

154164
func (m *Mysql) Lock() error {
@@ -162,9 +172,9 @@ func (m *Mysql) Lock() error {
162172
return err
163173
}
164174

165-
query := "SELECT GET_LOCK(?, 1)"
175+
query := "SELECT GET_LOCK(?, 10)"
166176
var success bool
167-
if err := m.db.QueryRow(query, aid).Scan(&success); err != nil {
177+
if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
168178
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
169179
}
170180

@@ -188,10 +198,14 @@ func (m *Mysql) Unlock() error {
188198
}
189199

190200
query := `SELECT RELEASE_LOCK(?)`
191-
if _, err := m.db.Exec(query, aid); err != nil {
201+
if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
192202
return &database.Error{OrigErr: err, Query: []byte(query)}
193203
}
194204

205+
// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
206+
// in which case isLocked should be true until the timeout expires -- synchronizing
207+
// these states is likely not worth trying to do; reconsider the necessity of isLocked.
208+
195209
m.isLocked = false
196210
return nil
197211
}
@@ -203,27 +217,28 @@ func (m *Mysql) Run(migration io.Reader) error {
203217
}
204218

205219
query := string(migr[:])
206-
if _, err := m.db.Exec(query); err != nil {
220+
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
207221
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
208222
}
209223

210224
return nil
211225
}
212226

213227
func (m *Mysql) SetVersion(version int, dirty bool) error {
214-
tx, err := m.db.Begin()
228+
tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{})
215229
if err != nil {
216230
return &database.Error{OrigErr: err, Err: "transaction start failed"}
217231
}
218232

219233
query := "TRUNCATE `" + m.config.MigrationsTable + "`"
220-
if _, err := m.db.Exec(query); err != nil {
234+
if _, err := tx.ExecContext(context.Background(), query); err != nil {
235+
tx.Rollback()
221236
return &database.Error{OrigErr: err, Query: []byte(query)}
222237
}
223238

224239
if version >= 0 {
225240
query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
226-
if _, err := m.db.Exec(query, version, dirty); err != nil {
241+
if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
227242
tx.Rollback()
228243
return &database.Error{OrigErr: err, Query: []byte(query)}
229244
}
@@ -238,7 +253,7 @@ func (m *Mysql) SetVersion(version int, dirty bool) error {
238253

239254
func (m *Mysql) Version() (version int, dirty bool, err error) {
240255
query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
241-
err = m.db.QueryRow(query).Scan(&version, &dirty)
256+
err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
242257
switch {
243258
case err == sql.ErrNoRows:
244259
return database.NilVersion, false, nil
@@ -259,7 +274,7 @@ func (m *Mysql) Version() (version int, dirty bool, err error) {
259274
func (m *Mysql) Drop() error {
260275
// select all tables
261276
query := `SHOW TABLES LIKE '%'`
262-
tables, err := m.db.Query(query)
277+
tables, err := m.conn.QueryContext(context.Background(), query)
263278
if err != nil {
264279
return &database.Error{OrigErr: err, Query: []byte(query)}
265280
}
@@ -281,7 +296,7 @@ func (m *Mysql) Drop() error {
281296
// delete one by one ...
282297
for _, t := range tableNames {
283298
query = "DROP TABLE IF EXISTS `" + t + "` CASCADE"
284-
if _, err := m.db.Exec(query); err != nil {
299+
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
285300
return &database.Error{OrigErr: err, Query: []byte(query)}
286301
}
287302
}
@@ -297,7 +312,7 @@ func (m *Mysql) ensureVersionTable() error {
297312
// check if migration table exists
298313
var result string
299314
query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"`
300-
if err := m.db.QueryRow(query).Scan(&result); err != nil {
315+
if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
301316
if err != sql.ErrNoRows {
302317
return &database.Error{OrigErr: err, Query: []byte(query)}
303318
}
@@ -307,7 +322,7 @@ func (m *Mysql) ensureVersionTable() error {
307322

308323
// if not, create the empty migration table
309324
query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
310-
if _, err := m.db.Exec(query); err != nil {
325+
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
311326
return &database.Error{OrigErr: err, Query: []byte(query)}
312327
}
313328
return nil

database/mysql/mysql_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,37 @@ func Test(t *testing.T) {
6363
}
6464
})
6565
}
66+
67+
func TestLockWorks(t *testing.T) {
68+
mt.ParallelTest(t, versions, isReady,
69+
func(t *testing.T, i mt.Instance) {
70+
p := &Mysql{}
71+
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", i.Host(), i.Port())
72+
d, err := p.Open(addr)
73+
if err != nil {
74+
t.Fatalf("%v", err)
75+
}
76+
dt.Test(t, d, []byte("SELECT 1"))
77+
78+
ms := d.(*Mysql)
79+
80+
err = ms.Lock()
81+
if err != nil {
82+
t.Fatal(err)
83+
}
84+
err = ms.Unlock()
85+
if err != nil {
86+
t.Fatal(err)
87+
}
88+
89+
// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
90+
err = ms.Lock()
91+
if err != nil {
92+
t.Fatal(err)
93+
}
94+
err = ms.Unlock()
95+
if err != nil {
96+
t.Fatal(err)
97+
}
98+
})
99+
}

0 commit comments

Comments
 (0)