@@ -33,7 +33,7 @@ type dbStorageClient struct {
33
33
34
34
func newClient (ctx context.Context , driverName string , db * sql.DB , tableName string ) (* dbStorageClient , error ) {
35
35
createTableSQL := createTable
36
- if driverName == "sqlite" {
36
+ if driverName == driverSqlite {
37
37
createTableSQL = createTableSqlite
38
38
}
39
39
var err error
@@ -59,45 +59,37 @@ func newClient(ctx context.Context, driverName string, db *sql.DB, tableName str
59
59
60
60
// Get will retrieve data from storage that corresponds to the specified key
61
61
func (c * dbStorageClient ) Get (ctx context.Context , key string ) ([]byte , error ) {
62
- rows , err := c .getQuery .QueryContext (ctx , key )
63
- if err != nil {
64
- return nil , err
65
- }
66
- if ! rows .Next () {
67
- return nil , nil
68
- }
69
- var result []byte
70
- err = rows .Scan (& result )
71
- if err != nil {
72
- return result , err
73
- }
74
- err = rows .Close ()
75
- return result , err
62
+ return c .get (ctx , key , nil )
76
63
}
77
64
78
65
// Set will store data. The data can be retrieved using the same key
79
66
func (c * dbStorageClient ) Set (ctx context.Context , key string , value []byte ) error {
80
- _ , err := c .setQuery .ExecContext (ctx , key , value , value )
81
- return err
67
+ return c .set (ctx , key , value , nil )
82
68
}
83
69
84
70
// Delete will delete data associated with the specified key
85
71
func (c * dbStorageClient ) Delete (ctx context.Context , key string ) error {
86
- _ , err := c .deleteQuery .ExecContext (ctx , key )
87
- return err
72
+ return c .delete (ctx , key , nil )
88
73
}
89
74
90
75
// Batch executes the specified operations in order. Get operation results are updated in place
91
76
func (c * dbStorageClient ) Batch (ctx context.Context , ops ... * storage.Operation ) error {
92
- var err error
77
+ // Start a new transaction
78
+ tx , err := c .db .BeginTx (ctx , nil )
79
+ if err != nil {
80
+ return err
81
+ }
82
+ //nolint:errcheck
83
+ defer tx .Rollback ()
84
+
93
85
for _ , op := range ops {
94
86
switch op .Type {
95
87
case storage .Get :
96
- op .Value , err = c .Get (ctx , op .Key )
88
+ op .Value , err = c .get (ctx , op .Key , tx )
97
89
case storage .Set :
98
- err = c .Set (ctx , op .Key , op .Value )
90
+ err = c .set (ctx , op .Key , op .Value , tx )
99
91
case storage .Delete :
100
- err = c .Delete (ctx , op .Key )
92
+ err = c .delete (ctx , op .Key , tx )
101
93
default :
102
94
return errors .New ("wrong operation type" )
103
95
}
@@ -106,7 +98,8 @@ func (c *dbStorageClient) Batch(ctx context.Context, ops ...*storage.Operation)
106
98
return err
107
99
}
108
100
}
109
- return err
101
+
102
+ return tx .Commit ()
110
103
}
111
104
112
105
// Close will close the database
@@ -119,3 +112,39 @@ func (c *dbStorageClient) Close(_ context.Context) error {
119
112
}
120
113
return c .getQuery .Close ()
121
114
}
115
+
116
+ func (c * dbStorageClient ) get (ctx context.Context , key string , tx * sql.Tx ) ([]byte , error ) {
117
+ rows , err := c .wrapTx (c .getQuery , tx ).QueryContext (ctx , key )
118
+ if err != nil {
119
+ return nil , err
120
+ }
121
+
122
+ if ! rows .Next () {
123
+ return nil , nil
124
+ }
125
+
126
+ var result []byte
127
+ if err := rows .Scan (& result ); err != nil {
128
+ return result , err
129
+ }
130
+
131
+ return result , rows .Close ()
132
+ }
133
+
134
+ func (c * dbStorageClient ) set (ctx context.Context , key string , value []byte , tx * sql.Tx ) error {
135
+ _ , err := c .wrapTx (c .setQuery , tx ).ExecContext (ctx , key , value , value )
136
+ return err
137
+ }
138
+
139
+ func (c * dbStorageClient ) delete (ctx context.Context , key string , tx * sql.Tx ) error {
140
+ _ , err := c .wrapTx (c .deleteQuery , tx ).ExecContext (ctx , key )
141
+ return err
142
+ }
143
+
144
+ func (c * dbStorageClient ) wrapTx (stmt * sql.Stmt , tx * sql.Tx ) * sql.Stmt {
145
+ if tx != nil {
146
+ return tx .Stmt (stmt )
147
+ }
148
+
149
+ return stmt
150
+ }
0 commit comments