@@ -9,11 +9,12 @@ import (
9
9
"errors"
10
10
"fmt"
11
11
12
+ "go.opentelemetry.io/collector/extension/xextension/storage"
13
+
12
14
// Postgres driver
13
15
_ "github.com/jackc/pgx/v5/stdlib"
14
16
// SQLite driver
15
17
_ "github.com/mattn/go-sqlite3"
16
- "go.opentelemetry.io/collector/extension/xextension/storage"
17
18
)
18
19
19
20
const (
@@ -33,7 +34,7 @@ type dbStorageClient struct {
33
34
34
35
func newClient (ctx context.Context , driverName string , db * sql.DB , tableName string ) (* dbStorageClient , error ) {
35
36
createTableSQL := createTable
36
- if driverName == "sqlite" {
37
+ if driverName == driverSqlite {
37
38
createTableSQL = createTableSqlite
38
39
}
39
40
var err error
@@ -59,45 +60,37 @@ func newClient(ctx context.Context, driverName string, db *sql.DB, tableName str
59
60
60
61
// Get will retrieve data from storage that corresponds to the specified key
61
62
func (c * dbStorageClient ) Get (ctx context.Context , key string ) ([]byte , error ) {
62
- rows , err := c .getQuery .QueryContext (ctx , key )
63
- if err != nil {
64
- return nil , err
65
- }
66
- if ! rows .Next () {
67
- return nil , nil
68
- }
69
- var result []byte
70
- err = rows .Scan (& result )
71
- if err != nil {
72
- return result , err
73
- }
74
- err = rows .Close ()
75
- return result , err
63
+ return c .get (ctx , key , nil )
76
64
}
77
65
78
66
// Set will store data. The data can be retrieved using the same key
79
67
func (c * dbStorageClient ) Set (ctx context.Context , key string , value []byte ) error {
80
- _ , err := c .setQuery .ExecContext (ctx , key , value , value )
81
- return err
68
+ return c .set (ctx , key , value , nil )
82
69
}
83
70
84
71
// Delete will delete data associated with the specified key
85
72
func (c * dbStorageClient ) Delete (ctx context.Context , key string ) error {
86
- _ , err := c .deleteQuery .ExecContext (ctx , key )
87
- return err
73
+ return c .delete (ctx , key , nil )
88
74
}
89
75
90
76
// Batch executes the specified operations in order. Get operation results are updated in place
91
77
func (c * dbStorageClient ) Batch (ctx context.Context , ops ... * storage.Operation ) error {
92
- var err error
78
+ // Start a new transaction
79
+ tx , err := c .db .BeginTx (ctx , nil )
80
+ if err != nil {
81
+ return err
82
+ }
83
+ //nolint:errcheck
84
+ defer tx .Rollback ()
85
+
93
86
for _ , op := range ops {
94
87
switch op .Type {
95
88
case storage .Get :
96
- op .Value , err = c .Get (ctx , op .Key )
89
+ op .Value , err = c .get (ctx , op .Key , tx )
97
90
case storage .Set :
98
- err = c .Set (ctx , op .Key , op .Value )
91
+ err = c .set (ctx , op .Key , op .Value , tx )
99
92
case storage .Delete :
100
- err = c .Delete (ctx , op .Key )
93
+ err = c .delete (ctx , op .Key , tx )
101
94
default :
102
95
return errors .New ("wrong operation type" )
103
96
}
@@ -106,7 +99,8 @@ func (c *dbStorageClient) Batch(ctx context.Context, ops ...*storage.Operation)
106
99
return err
107
100
}
108
101
}
109
- return err
102
+
103
+ return tx .Commit ()
110
104
}
111
105
112
106
// Close will close the database
@@ -119,3 +113,39 @@ func (c *dbStorageClient) Close(_ context.Context) error {
119
113
}
120
114
return c .getQuery .Close ()
121
115
}
116
+
117
+ func (c * dbStorageClient ) get (ctx context.Context , key string , tx * sql.Tx ) ([]byte , error ) {
118
+ rows , err := c .wrapTx (c .getQuery , tx ).QueryContext (ctx , key )
119
+ if err != nil {
120
+ return nil , err
121
+ }
122
+
123
+ if ! rows .Next () {
124
+ return nil , nil
125
+ }
126
+
127
+ var result []byte
128
+ if err := rows .Scan (& result ); err != nil {
129
+ return result , err
130
+ }
131
+
132
+ return result , rows .Close ()
133
+ }
134
+
135
+ func (c * dbStorageClient ) set (ctx context.Context , key string , value []byte , tx * sql.Tx ) error {
136
+ _ , err := c .wrapTx (c .setQuery , tx ).ExecContext (ctx , key , value , value )
137
+ return err
138
+ }
139
+
140
+ func (c * dbStorageClient ) delete (ctx context.Context , key string , tx * sql.Tx ) error {
141
+ _ , err := c .wrapTx (c .deleteQuery , tx ).ExecContext (ctx , key )
142
+ return err
143
+ }
144
+
145
+ func (c * dbStorageClient ) wrapTx (stmt * sql.Stmt , tx * sql.Tx ) * sql.Stmt {
146
+ if tx != nil {
147
+ return tx .Stmt (stmt )
148
+ }
149
+
150
+ return stmt
151
+ }
0 commit comments