@@ -14,6 +14,7 @@ import (
14
14
// SQLite driver
15
15
_ "github.com/mattn/go-sqlite3"
16
16
"go.opentelemetry.io/collector/extension/xextension/storage"
17
+ "go.uber.org/zap"
17
18
)
18
19
19
20
const (
@@ -25,15 +26,16 @@ const (
25
26
)
26
27
27
28
type dbStorageClient struct {
29
+ logger * zap.Logger
28
30
db * sql.DB
29
31
getQuery * sql.Stmt
30
32
setQuery * sql.Stmt
31
33
deleteQuery * sql.Stmt
32
34
}
33
35
34
- func newClient (ctx context.Context , driverName string , db * sql.DB , tableName string ) (* dbStorageClient , error ) {
36
+ func newClient (ctx context.Context , logger * zap. Logger , db * sql.DB , driverName string , tableName string ) (* dbStorageClient , error ) {
35
37
createTableSQL := createTable
36
- if driverName == "sqlite" {
38
+ if driverName == driverSQLite {
37
39
createTableSQL = createTableSqlite
38
40
}
39
41
var err error
@@ -54,50 +56,52 @@ func newClient(ctx context.Context, driverName string, db *sql.DB, tableName str
54
56
if err != nil {
55
57
return nil , err
56
58
}
57
- return & dbStorageClient {db , selectQuery , setQuery , deleteQuery }, nil
59
+ return & dbStorageClient {logger , db , selectQuery , setQuery , deleteQuery }, nil
58
60
}
59
61
60
62
// Get will retrieve data from storage that corresponds to the specified key
61
63
func (c * dbStorageClient ) Get (ctx context.Context , key string ) ([]byte , error ) {
62
- rows , err := c .getQuery .QueryContext (ctx , key )
63
- if err != nil {
64
- return nil , err
65
- }
66
- if ! rows .Next () {
67
- return nil , nil
68
- }
69
- var result []byte
70
- err = rows .Scan (& result )
71
- if err != nil {
72
- return result , err
73
- }
74
- err = rows .Close ()
75
- return result , err
64
+ return c .get (ctx , key , nil )
76
65
}
77
66
78
67
// Set will store data. The data can be retrieved using the same key
79
68
func (c * dbStorageClient ) Set (ctx context.Context , key string , value []byte ) error {
80
- _ , err := c .setQuery .ExecContext (ctx , key , value , value )
81
- return err
69
+ return c .set (ctx , key , value , nil )
82
70
}
83
71
84
72
// Delete will delete data associated with the specified key
85
73
func (c * dbStorageClient ) Delete (ctx context.Context , key string ) error {
86
- _ , err := c .deleteQuery .ExecContext (ctx , key )
87
- return err
74
+ return c .delete (ctx , key , nil )
88
75
}
89
76
90
77
// Batch executes the specified operations in order. Get operation results are updated in place
91
78
func (c * dbStorageClient ) Batch (ctx context.Context , ops ... * storage.Operation ) error {
92
- var err error
79
+ // Start a new transaction
80
+ tx , err := c .db .BeginTx (ctx , nil )
81
+ if err != nil {
82
+ return err
83
+ }
84
+
85
+ // In case of any error we should roll back whole transaction to keep DB in consistent state
86
+ // In case of successful commit - tx.Rollback() will be a no-op here as tx is already closed
87
+ defer func () {
88
+ // We should ignore error related already finished transaction here
89
+ // It might happened, for example, if Context was canceled outside of Batch() function
90
+ // in this case whole transaction will be rolled back by sql package and we'll receive ErrTxDone here,
91
+ // which is actually not an issue because transaction was correctly closed with rollback
92
+ if rollbackErr := tx .Rollback (); ! errors .Is (rollbackErr , sql .ErrTxDone ) {
93
+ c .logger .Error ("Failed to rollback Batch() transaction" , zap .Error (rollbackErr ))
94
+ }
95
+ }()
96
+
93
97
for _ , op := range ops {
94
98
switch op .Type {
95
99
case storage .Get :
96
- op .Value , err = c .Get (ctx , op .Key )
100
+ op .Value , err = c .get (ctx , op .Key , tx )
97
101
case storage .Set :
98
- err = c .Set (ctx , op .Key , op .Value )
102
+ err = c .set (ctx , op .Key , op .Value , tx )
99
103
case storage .Delete :
100
- err = c .Delete (ctx , op .Key )
104
+ err = c .delete (ctx , op .Key , tx )
101
105
default :
102
106
return errors .New ("wrong operation type" )
103
107
}
@@ -106,7 +110,8 @@ func (c *dbStorageClient) Batch(ctx context.Context, ops ...*storage.Operation)
106
110
return err
107
111
}
108
112
}
109
- return err
113
+
114
+ return tx .Commit ()
110
115
}
111
116
112
117
// Close will close the database
@@ -119,3 +124,39 @@ func (c *dbStorageClient) Close(_ context.Context) error {
119
124
}
120
125
return c .getQuery .Close ()
121
126
}
127
+
128
+ func (c * dbStorageClient ) get (ctx context.Context , key string , tx * sql.Tx ) ([]byte , error ) {
129
+ rows , err := c .wrapTx (c .getQuery , tx ).QueryContext (ctx , key )
130
+ if err != nil {
131
+ return nil , err
132
+ }
133
+
134
+ if ! rows .Next () {
135
+ return nil , nil
136
+ }
137
+
138
+ var result []byte
139
+ if err := rows .Scan (& result ); err != nil {
140
+ return result , err
141
+ }
142
+
143
+ return result , rows .Close ()
144
+ }
145
+
146
+ func (c * dbStorageClient ) set (ctx context.Context , key string , value []byte , tx * sql.Tx ) error {
147
+ _ , err := c .wrapTx (c .setQuery , tx ).ExecContext (ctx , key , value , value )
148
+ return err
149
+ }
150
+
151
+ func (c * dbStorageClient ) delete (ctx context.Context , key string , tx * sql.Tx ) error {
152
+ _ , err := c .wrapTx (c .deleteQuery , tx ).ExecContext (ctx , key )
153
+ return err
154
+ }
155
+
156
+ func (c * dbStorageClient ) wrapTx (stmt * sql.Stmt , tx * sql.Tx ) * sql.Stmt {
157
+ if tx != nil {
158
+ return tx .Stmt (stmt )
159
+ }
160
+
161
+ return stmt
162
+ }
0 commit comments