Skip to content

Commit 56433bc

Browse files
authored
[chore] Extract batcher worker pool, cleanup unit tests (#13164)
Signed-off-by: Bogdan Drutu <[email protected]>
1 parent c9aaed8 commit 56433bc

File tree

3 files changed

+66
-96
lines changed

3 files changed

+66
-96
lines changed

exporter/exporterhelper/internal/queuebatch/multi_batcher.go

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type batcherSettings[T any] struct {
2121

2222
type multiBatcher struct {
2323
cfg BatchConfig
24-
wp chan struct{}
24+
wp *workerPool
2525
sizerType request.SizerType
2626
sizer request.Sizer[request.Request]
2727
partitioner Partitioner[request.Request]
@@ -34,16 +34,9 @@ type multiBatcher struct {
3434
var _ Batcher[request.Request] = (*multiBatcher)(nil)
3535

3636
func newMultiBatcher(bCfg BatchConfig, bSet batcherSettings[request.Request]) *multiBatcher {
37-
var workerPool chan struct{}
38-
if bSet.maxWorkers != 0 {
39-
workerPool = make(chan struct{}, bSet.maxWorkers)
40-
for i := 0; i < bSet.maxWorkers; i++ {
41-
workerPool <- struct{}{}
42-
}
43-
}
4437
mb := &multiBatcher{
4538
cfg: bCfg,
46-
wp: workerPool,
39+
wp: newWorkerPool(bSet.maxWorkers),
4740
sizerType: bSet.sizerType,
4841
sizer: bSet.sizer,
4942
partitioner: bSet.partitioner,
@@ -68,18 +61,18 @@ func (mb *multiBatcher) getShard(ctx context.Context, req request.Request) *shar
6861
return s.(*shardBatcher)
6962
}
7063
newS := newShard(mb.cfg, mb.sizerType, mb.sizer, mb.wp, mb.consumeFunc)
71-
newS.start(ctx, nil)
64+
_ = newS.Start(ctx, nil)
7265
s, loaded := mb.shards.LoadOrStore(key, newS)
7366
// If not loaded, there was a race condition in adding the new shard. Shutdown the newly created shard.
7467
if loaded {
75-
newS.shutdown(ctx)
68+
_ = newS.Shutdown(ctx)
7669
}
7770
return s.(*shardBatcher)
7871
}
7972

8073
// Start starts the single shard when partitioning is disabled. With a
// partitioner, shards are created (and started) lazily in getShard, so
// there is nothing to do here.
func (mb *multiBatcher) Start(ctx context.Context, host component.Host) error {
	if mb.singleShard != nil {
		return mb.singleShard.Start(ctx, host)
	}
	return nil
}
@@ -91,16 +84,15 @@ func (mb *multiBatcher) Consume(ctx context.Context, req request.Request, done D
9184

9285
func (mb *multiBatcher) Shutdown(ctx context.Context) error {
9386
if mb.singleShard != nil {
94-
mb.singleShard.shutdown(ctx)
95-
return nil
87+
return mb.singleShard.Shutdown(ctx)
9688
}
9789

9890
var wg sync.WaitGroup
9991
mb.shards.Range(func(_ any, shard any) bool {
10092
wg.Add(1)
10193
go func() {
10294
defer wg.Done()
103-
shard.(*shardBatcher).shutdown(ctx)
95+
_ = shard.(*shardBatcher).Shutdown(ctx)
10496
}()
10597
return true
10698
})

exporter/exporterhelper/internal/queuebatch/shard_batcher.go

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ type batch struct {
2424
// shardBatcher continuously batch incoming requests and flushes asynchronously if minimum size limit is met or on timeout.
2525
type shardBatcher struct {
2626
cfg BatchConfig
27-
workerPool chan struct{}
27+
wp *workerPool
2828
sizerType request.SizerType
2929
sizer request.Sizer[request.Request]
3030
consumeFunc sender.SendFunc[request.Request]
@@ -35,10 +35,10 @@ type shardBatcher struct {
3535
shutdownCh chan struct{}
3636
}
3737

38-
func newShard(cfg BatchConfig, sizerType request.SizerType, sizer request.Sizer[request.Request], workerPool chan struct{}, next sender.SendFunc[request.Request]) *shardBatcher {
38+
func newShard(cfg BatchConfig, sizerType request.SizerType, sizer request.Sizer[request.Request], wp *workerPool, next sender.SendFunc[request.Request]) *shardBatcher {
3939
return &shardBatcher{
4040
cfg: cfg,
41-
workerPool: workerPool,
41+
wp: wp,
4242
sizerType: sizerType,
4343
sizer: sizer,
4444
consumeFunc: next,
@@ -149,22 +149,33 @@ func (qb *shardBatcher) Consume(ctx context.Context, req request.Request, done D
149149
}
150150

151151
// Start starts the goroutine that reads from the queue and flushes asynchronously.
152-
func (qb *shardBatcher) start(_ context.Context, _ component.Host) {
153-
if qb.cfg.FlushTimeout > 0 {
154-
qb.timer = time.NewTimer(qb.cfg.FlushTimeout)
155-
qb.stopWG.Add(1)
156-
go func() {
157-
defer qb.stopWG.Done()
158-
for {
159-
select {
160-
case <-qb.shutdownCh:
161-
return
162-
case <-qb.timer.C:
163-
qb.flushCurrentBatchIfNecessary()
164-
}
165-
}
166-
}()
152+
// Start launches the background goroutine that flushes the current batch
// whenever the configured FlushTimeout elapses. When FlushTimeout is not
// positive, no goroutine is started and flushing happens only when size
// thresholds are reached elsewhere. Always returns nil.
func (qb *shardBatcher) Start(context.Context, component.Host) error {
	if qb.cfg.FlushTimeout <= 0 {
		return nil
	}
	qb.timer = time.NewTimer(qb.cfg.FlushTimeout)
	// Track the timer goroutine in stopWG so Shutdown can wait for it.
	qb.stopWG.Add(1)
	go func() {
		defer qb.stopWG.Done()
		for {
			select {
			case <-qb.shutdownCh:
				// Shutdown closed shutdownCh: stop the timer loop.
				return
			case <-qb.timer.C:
				qb.flushCurrentBatchIfNecessary()
			}
		}
	}()
	return nil
}
171+
172+
// Shutdown ensures that queue and all Batcher are stopped.
// Order matters here: first signal the timer goroutine to exit, then flush
// whatever is still buffered, then wait for every in-flight flush tracked
// by stopWG to complete. Always returns nil.
func (qb *shardBatcher) Shutdown(context.Context) error {
	close(qb.shutdownCh)
	// Make sure execute one last flush if necessary.
	qb.flushCurrentBatchIfNecessary()
	qb.stopWG.Wait()
	return nil
}
169180

170181
// flushCurrentBatchIfNecessary sends out the current request batch if it is not nil
@@ -186,24 +197,28 @@ func (qb *shardBatcher) flushCurrentBatchIfNecessary() {
186197
// flush starts a goroutine that calls consumeFunc. It blocks until a worker is available if necessary.
func (qb *shardBatcher) flush(ctx context.Context, req request.Request, done Done) {
	// Register the in-flight flush with stopWG before handing it to the
	// pool so that Shutdown's stopWG.Wait() is guaranteed to cover it.
	qb.stopWG.Add(1)
	qb.wp.execute(func() {
		defer qb.stopWG.Done()
		// Report the export result to the done callback.
		done.OnDone(qb.consumeFunc(ctx, req))
	})
}
200205

201-
// Shutdown ensures that queue and all Batcher are stopped.
202-
func (qb *shardBatcher) shutdown(_ context.Context) {
203-
close(qb.shutdownCh)
204-
// Make sure execute one last flush if necessary.
205-
qb.flushCurrentBatchIfNecessary()
206-
qb.stopWG.Wait()
206+
// workerPool bounds the number of goroutines concurrently executing
// submitted work. A nil workers channel means no limit.
type workerPool struct {
	workers chan struct{}
}

// newWorkerPool returns a pool that runs at most maxWorkers submitted
// functions concurrently. maxWorkers <= 0 means unlimited, matching the
// pre-refactor behavior where a zero value disabled the worker limit
// (the previous code built the token channel only "if maxWorkers != 0";
// an unbuffered channel here would deadlock the first execute call).
func newWorkerPool(maxWorkers int) *workerPool {
	if maxWorkers <= 0 {
		return &workerPool{}
	}
	workers := make(chan struct{}, maxWorkers)
	for i := 0; i < maxWorkers; i++ {
		workers <- struct{}{}
	}
	return &workerPool{workers: workers}
}

// execute runs f on a new goroutine. When the pool is bounded it blocks
// until a worker slot is free and releases the slot only after f returns;
// releasing it immediately after spawning (as the previous version did)
// would let an unbounded number of f's run at once.
func (wp *workerPool) execute(f func()) {
	if wp.workers == nil {
		go f()
		return
	}
	<-wp.workers
	go func() {
		// Return the token only once the work has actually finished.
		defer func() { wp.workers <- struct{}{} }()
		f()
	}()
}
208223

209224
type multiDone []Done

exporter/exporterhelper/internal/queuebatch/shard_batcher_test.go

Lines changed: 12 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,7 @@ func TestShardBatcher_NoSplit_MinThresholdZero_TimeoutDisabled(t *testing.T) {
5959
}
6060

6161
sink := requesttest.NewSink()
62-
ba := newMultiBatcher(cfg, batcherSettings[request.Request]{
63-
sizerType: tt.sizerType,
64-
sizer: tt.sizer,
65-
partitioner: nil,
66-
next: sink.Export,
67-
maxWorkers: tt.maxWorkers,
68-
})
62+
ba := newShard(cfg, tt.sizerType, tt.sizer, newWorkerPool(tt.maxWorkers), sink.Export)
6963
require.NoError(t, ba.Start(context.Background(), componenttest.NewNopHost()))
7064
t.Cleanup(func() {
7165
require.NoError(t, ba.Shutdown(context.Background()))
@@ -83,7 +77,7 @@ func TestShardBatcher_NoSplit_MinThresholdZero_TimeoutDisabled(t *testing.T) {
8377
assert.Eventually(t, func() bool {
8478
return sink.RequestsCount() == 5 && (sink.ItemsCount() == 75 || sink.BytesCount() == 75)
8579
}, 1*time.Second, 10*time.Millisecond)
86-
// Check that done callback is called for the right amount of times.
80+
// Check that done callback is called for the right number of times.
8781
assert.EqualValues(t, 1, done.errors.Load())
8882
assert.EqualValues(t, 5, done.success.Load())
8983
})
@@ -130,13 +124,7 @@ func TestShardBatcher_NoSplit_TimeoutDisabled(t *testing.T) {
130124
}
131125

132126
sink := requesttest.NewSink()
133-
ba := newMultiBatcher(cfg, batcherSettings[request.Request]{
134-
sizerType: tt.sizerType,
135-
sizer: tt.sizer,
136-
partitioner: nil,
137-
next: sink.Export,
138-
maxWorkers: tt.maxWorkers,
139-
})
127+
ba := newShard(cfg, tt.sizerType, tt.sizer, newWorkerPool(tt.maxWorkers), sink.Export)
140128
require.NoError(t, ba.Start(context.Background(), componenttest.NewNopHost()))
141129

142130
done := newFakeDone()
@@ -165,7 +153,7 @@ func TestShardBatcher_NoSplit_TimeoutDisabled(t *testing.T) {
165153
assert.Equal(t, 3, sink.RequestsCount())
166154
assert.True(t, sink.ItemsCount() == 57 || sink.BytesCount() == 57)
167155

168-
// Check that done callback is called for the right amount of times.
156+
// Check that done callback is called for the right number of times.
169157
assert.EqualValues(t, 3, done.errors.Load())
170158
assert.EqualValues(t, 4, done.success.Load())
171159
})
@@ -216,13 +204,7 @@ func TestShardBatcher_NoSplit_WithTimeout(t *testing.T) {
216204
}
217205

218206
sink := requesttest.NewSink()
219-
ba := newMultiBatcher(cfg, batcherSettings[request.Request]{
220-
sizerType: tt.sizerType,
221-
sizer: tt.sizer,
222-
partitioner: nil,
223-
next: sink.Export,
224-
maxWorkers: tt.maxWorkers,
225-
})
207+
ba := newShard(cfg, tt.sizerType, tt.sizer, newWorkerPool(tt.maxWorkers), sink.Export)
226208
require.NoError(t, ba.Start(context.Background(), componenttest.NewNopHost()))
227209
t.Cleanup(func() {
228210
require.NoError(t, ba.Shutdown(context.Background()))
@@ -241,7 +223,7 @@ func TestShardBatcher_NoSplit_WithTimeout(t *testing.T) {
241223
return sink.RequestsCount() == 1 && (sink.ItemsCount() == 75 || sink.BytesCount() == 75)
242224
}, 1*time.Second, 10*time.Millisecond)
243225

244-
// Check that done callback is called for the right amount of times.
226+
// Check that done callback is called for the right number of times.
245227
assert.EqualValues(t, 1, done.errors.Load())
246228
assert.EqualValues(t, 5, done.success.Load())
247229
})
@@ -293,13 +275,7 @@ func TestShardBatcher_Split_TimeoutDisabled(t *testing.T) {
293275
}
294276

295277
sink := requesttest.NewSink()
296-
ba := newMultiBatcher(cfg, batcherSettings[request.Request]{
297-
sizerType: tt.sizerType,
298-
sizer: tt.sizer,
299-
partitioner: nil,
300-
next: sink.Export,
301-
maxWorkers: tt.maxWorkers,
302-
})
278+
ba := newShard(cfg, tt.sizerType, tt.sizer, newWorkerPool(tt.maxWorkers), sink.Export)
303279
require.NoError(t, ba.Start(context.Background(), componenttest.NewNopHost()))
304280

305281
done := newFakeDone()
@@ -332,7 +308,7 @@ func TestShardBatcher_Split_TimeoutDisabled(t *testing.T) {
332308
assert.Equal(t, 11, sink.RequestsCount())
333309
assert.True(t, sink.ItemsCount() == 1005 || sink.BytesCount() == 1005)
334310

335-
// Check that done callback is called for the right amount of times.
311+
// Check that done callback is called for the right number of times.
336312
assert.EqualValues(t, 2, done.errors.Load())
337313
assert.EqualValues(t, 7, done.success.Load())
338314
})
@@ -346,13 +322,7 @@ func TestShardBatcher_Shutdown(t *testing.T) {
346322
}
347323

348324
sink := requesttest.NewSink()
349-
ba := newMultiBatcher(cfg, batcherSettings[request.Request]{
350-
sizerType: request.SizerTypeItems,
351-
sizer: request.NewItemsSizer(),
352-
partitioner: nil,
353-
next: sink.Export,
354-
maxWorkers: 2,
355-
})
325+
ba := newShard(cfg, request.SizerTypeItems, request.NewItemsSizer(), newWorkerPool(2), sink.Export)
356326
require.NoError(t, ba.Start(context.Background(), componenttest.NewNopHost()))
357327

358328
done := newFakeDone()
@@ -367,7 +337,7 @@ func TestShardBatcher_Shutdown(t *testing.T) {
367337
assert.Equal(t, 1, sink.RequestsCount())
368338
assert.Equal(t, 3, sink.ItemsCount())
369339

370-
// Check that done callback is called for the right amount of times.
340+
// Check that done callback is called for the right number of times.
371341
assert.EqualValues(t, 0, done.errors.Load())
372342
assert.EqualValues(t, 2, done.success.Load())
373343
}
@@ -380,14 +350,7 @@ func TestShardBatcher_MergeError(t *testing.T) {
380350
}
381351

382352
sink := requesttest.NewSink()
383-
ba := newMultiBatcher(cfg, batcherSettings[request.Request]{
384-
sizerType: request.SizerTypeItems,
385-
sizer: request.NewItemsSizer(),
386-
partitioner: nil,
387-
next: sink.Export,
388-
maxWorkers: 2,
389-
})
390-
353+
ba := newShard(cfg, request.SizerTypeItems, request.NewItemsSizer(), newWorkerPool(2), sink.Export)
391354
require.NoError(t, ba.Start(context.Background(), componenttest.NewNopHost()))
392355
t.Cleanup(func() {
393356
require.NoError(t, ba.Shutdown(context.Background()))
@@ -405,7 +368,7 @@ func TestShardBatcher_MergeError(t *testing.T) {
405368
return done.errors.Load() == 2
406369
}, 1*time.Second, 10*time.Millisecond)
407370

408-
// Check that done callback is called for the right amount of times.
371+
// Check that done callback is called for the right number of times.
409372
assert.EqualValues(t, 2, done.errors.Load())
410373
assert.EqualValues(t, 0, done.success.Load())
411374
}

0 commit comments

Comments
 (0)