Skip to content

Commit d1b237b

Browse files
Merge pull request #5523 from hashicorp/backport/ddebko-fix-db-rw/recently-balanced-eel
This pull request was automerged via backport-assistant
2 parents cd35b25 + 2e2d18a commit d1b237b

20 files changed

+139
-27
lines changed

internal/auth/oidc/repository_auth_method.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ func (r *Repository) upsertAccount(ctx context.Context, am *AuthMethod, IdTokenC
179179
var rowCnt int
180180
for rows.Next() {
181181
rowCnt += 1
182-
err = r.reader.ScanRows(ctx, rows, &result)
182+
err = reader.ScanRows(ctx, rows, &result)
183183
if err != nil {
184184
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to scan rows for account"))
185185
}

internal/auth/oidc/repository_managed_group_members.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/hashicorp/boundary/internal/errors"
1212
"github.com/hashicorp/boundary/internal/kms"
1313
"github.com/hashicorp/boundary/internal/oplog"
14+
"github.com/hashicorp/boundary/internal/util"
1415
)
1516

1617
// SetManagedGroupMemberships will set the managed groups for the given account
@@ -207,7 +208,7 @@ func (r *Repository) ListManagedGroupMembershipsByMember(ctx context.Context, wi
207208
limit = opts.withLimit
208209
}
209210
reader := r.reader
210-
if opts.withReader != nil {
211+
if !util.IsNil(opts.withReader) {
211212
reader = opts.withReader
212213
}
213214
var mgs []*ManagedGroupMemberAccount
@@ -232,7 +233,7 @@ func (r *Repository) ListManagedGroupMembershipsByGroup(ctx context.Context, wit
232233
limit = opts.withLimit
233234
}
234235
reader := r.reader
235-
if opts.withReader != nil {
236+
if !util.IsNil(opts.withReader) {
236237
reader = opts.withReader
237238
}
238239
var mgs []*ManagedGroupMemberAccount

internal/auth/repository_auth_method.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func (amr *AuthMethodRepository) ListDeletedIds(ctx context.Context, since time.
147147
var deletedAuthMethodIDs []string
148148
var transactionTimestamp time.Time
149149
if _, err := amr.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
150-
rows, err := amr.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
150+
rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
151151
if err != nil {
152152
return errors.Wrap(ctx, err, op)
153153
}

internal/credential/repository_store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func (s *StoreRepository) ListDeletedIds(ctx context.Context, since time.Time) (
118118
var deletedStoreIDs []string
119119
var transactionTimestamp time.Time
120120
if _, err := s.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
121-
rows, err := s.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
121+
rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
122122
if err != nil {
123123
return errors.Wrap(ctx, err, op)
124124
}

internal/host/options.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ package host
66
import (
77
"errors"
88

9+
"github.com/hashicorp/boundary/internal/db"
910
"github.com/hashicorp/boundary/internal/pagination"
11+
"github.com/hashicorp/boundary/internal/util"
1012
)
1113

1214
// GetOpts - iterate the inbound Options and return a struct
@@ -26,6 +28,8 @@ type Option func(*options) error
2628
// options = how options are represented
2729
type options struct {
2830
WithLimit int
31+
WithReader db.Reader
32+
WithWriter db.Writer
2933
WithOrderByCreateTime bool
3034
Ascending bool
3135
WithStartPageAfterItem pagination.Item
@@ -66,3 +70,19 @@ func WithStartPageAfterItem(item pagination.Item) Option {
6670
return nil
6771
}
6872
}
73+
74+
// WithReaderWriter is used to share the same database reader
75+
// and writer when executing sql within a transaction.
76+
func WithReaderWriter(r db.Reader, w db.Writer) Option {
77+
return func(o *options) error {
78+
if util.IsNil(r) {
79+
return errors.New("reader cannot be nil")
80+
}
81+
if util.IsNil(w) {
82+
return errors.New("writer cannot be nil")
83+
}
84+
o.WithReader = r
85+
o.WithWriter = w
86+
return nil
87+
}
88+
}

internal/host/options_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,23 @@ func Test_GetOpts(t *testing.T) {
7777
assert.Equal(opts.WithStartPageAfterItem.GetPublicId(), "s_1")
7878
assert.Equal(opts.WithStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime))
7979
})
80+
t.Run("WithReaderWriter", func(t *testing.T) {
81+
t.Parallel()
82+
t.Run("nil writer", func(t *testing.T) {
83+
t.Parallel()
84+
_, err := GetOpts(WithReaderWriter(&db.Db{}, nil))
85+
require.Error(t, err)
86+
})
87+
t.Run("nil reader", func(t *testing.T) {
88+
t.Parallel()
89+
_, err := GetOpts(WithReaderWriter(nil, &db.Db{}))
90+
require.Error(t, err)
91+
})
92+
reader := &db.Db{}
93+
writer := &db.Db{}
94+
opts, err := GetOpts(WithReaderWriter(reader, writer))
95+
require.NoError(t, err)
96+
assert.Equal(t, reader, opts.WithReader)
97+
assert.Equal(t, writer, opts.WithWriter)
98+
})
8099
}

internal/host/plugin/options.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package plugin
55

66
import (
7+
"github.com/hashicorp/boundary/internal/db"
78
"github.com/hashicorp/boundary/internal/pagination"
89
"google.golang.org/protobuf/types/known/structpb"
910
)
@@ -38,6 +39,8 @@ type options struct {
3839
withSecretsHmac []byte
3940
withStartPageAfterItem pagination.Item
4041
withWorkerFilter string
42+
WithReader db.Reader
43+
withWriter db.Writer
4144
}
4245

4346
func getDefaultOptions() options {
@@ -162,3 +165,12 @@ func WithWorkerFilter(wf string) Option {
162165
o.withWorkerFilter = wf
163166
}
164167
}
168+
169+
// WithReaderWriter is used to share the same database reader
170+
// and writer when executing sql within a transaction.
171+
func WithReaderWriter(r db.Reader, w db.Writer) Option {
172+
return func(o *options) {
173+
o.WithReader = r
174+
o.withWriter = w
175+
}
176+
}

internal/host/plugin/options_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"testing"
88
"time"
99

10+
"github.com/hashicorp/boundary/internal/db"
1011
"github.com/hashicorp/boundary/internal/db/timestamp"
1112
"github.com/hashicorp/boundary/internal/pagination"
1213
"github.com/stretchr/testify/assert"
@@ -113,4 +114,11 @@ func Test_GetOpts(t *testing.T) {
113114
testOpts.withWorkerFilter = `"test" in "/tags/type"`
114115
assert.Equal(t, opts, testOpts)
115116
})
117+
t.Run("WithReaderWriter", func(t *testing.T) {
118+
reader := &db.Db{}
119+
writer := &db.Db{}
120+
opts := getOpts(WithReaderWriter(reader, writer))
121+
assert.Equal(t, reader, opts.WithReader)
122+
assert.Equal(t, writer, opts.withWriter)
123+
})
116124
}

internal/host/plugin/repository_host_catalog.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/hashicorp/boundary/internal/db"
1212
"github.com/hashicorp/boundary/internal/errors"
1313
"github.com/hashicorp/boundary/internal/event"
14+
"github.com/hashicorp/boundary/internal/host"
1415
"github.com/hashicorp/boundary/internal/kms"
1516
"github.com/hashicorp/boundary/internal/libs/patchstruct"
1617
"github.com/hashicorp/boundary/internal/oplog"
@@ -404,7 +405,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version
404405
ctx,
405406
db.StdRetryCnt,
406407
db.ExpBackoff{},
407-
func(_ db.Reader, w db.Writer) error {
408+
func(read db.Reader, w db.Writer) error {
408409
msgs := make([]*oplog.Message, 0, 3)
409410
ticket, err := w.GetTicket(ctx, newCatalog)
410411
if err != nil {
@@ -528,7 +529,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version
528529
if needSetSync {
529530
// We also need to mark all host sets in this catalog to be
530531
// synced as well.
531-
setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId)
532+
setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId, host.WithReaderWriter(read, w))
532533
if err != nil {
533534
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get sets for host catalog"))
534535
}
@@ -713,14 +714,19 @@ func (r *Repository) getCatalog(ctx context.Context, id string) (*HostCatalog, *
713714
return c, p, nil
714715
}
715716

716-
func (r *Repository) getPlugin(ctx context.Context, plgId string) (*plg.Plugin, error) {
717+
func (r *Repository) getPlugin(ctx context.Context, plgId string, opts ...Option) (*plg.Plugin, error) {
717718
const op = "plugin.(Repository).getPlugin"
718719
if plgId == "" {
719720
return nil, errors.New(ctx, errors.InvalidParameter, op, "no plugin id")
720721
}
722+
opt := getOpts(opts...)
723+
reader := r.reader
724+
if !util.IsNil(opt.WithReader) {
725+
reader = opt.WithReader
726+
}
721727
plg := plg.NewPlugin()
722728
plg.PublicId = plgId
723-
if err := r.reader.LookupByPublicId(ctx, plg); err != nil {
729+
if err := reader.LookupByPublicId(ctx, plg); err != nil {
724730
return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to get host plugin with id %q", plgId)))
725731
}
726732
return plg, nil

internal/host/plugin/repository_host_set.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,15 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
804804
limit = opts.WithLimit
805805
}
806806

807+
reader := r.reader
808+
writer := r.writer
809+
if !util.IsNil(opts.WithReader) {
810+
reader = opts.WithReader
811+
}
812+
if !util.IsNil(opts.WithWriter) {
813+
writer = opts.WithWriter
814+
}
815+
807816
args := make([]any, 0, 1)
808817
var where string
809818

@@ -825,7 +834,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
825834
}
826835

827836
var aggHostSets []*hostSetAgg
828-
if err := r.reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil {
837+
if err := reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil {
829838
return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("in %s", publicId)))
830839
}
831840

@@ -844,7 +853,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
844853
}
845854
var plg *plugin.Plugin
846855
if plgId != "" {
847-
plg, err = r.getPlugin(ctx, plgId)
856+
plg, err = r.getPlugin(ctx, plgId, WithReaderWriter(reader, writer))
848857
if err != nil {
849858
return nil, nil, errors.Wrap(ctx, err, op)
850859
}

0 commit comments

Comments
 (0)