Skip to content

Commit 4490460

Browse files
authored
enforce ordering of libbeat API server Start/Stop functions (#46865)
* enforce ordering of libbeat API server Start/Stop functions
1 parent 64de32c commit 4490460

File tree

2 files changed

+143
-22
lines changed

2 files changed

+143
-22
lines changed

libbeat/api/server.go

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ import (
3030
"github.com/elastic/elastic-agent-libs/logp"
3131
)
3232

33+
type serverState int
34+
35+
const (
36+
stateNew = iota
37+
stateStarted
38+
stateStopped
39+
)
40+
3341
// Server takes care of correctly starting the HTTP component of the API
3442
// and will answer all the routes defined in the received ServeMux.
3543
type Server struct {
@@ -40,6 +48,7 @@ type Server struct {
4048
wg sync.WaitGroup
4149
mutex sync.Mutex
4250
httpServer *http.Server
51+
state serverState
4352
}
4453

4554
// New creates a new API Server with no routes attached.
@@ -60,37 +69,69 @@ func New(log *logp.Logger, config *config.C) (*Server, error) {
6069
l: l,
6170
config: cfg,
6271
log: log.Named("api"),
72+
state: stateNew,
6373
}, nil
6474
}
6575

6676
// Start starts the HTTP server and accepting new connection.
6777
func (s *Server) Start() {
6878
s.mutex.Lock()
6979
defer s.mutex.Unlock()
70-
s.log.Info("Starting stats endpoint")
71-
s.wg.Add(1)
72-
s.httpServer = &http.Server{Handler: s.mux} //nolint:gosec // Keep original behavior
73-
go func(l net.Listener) {
74-
defer s.wg.Done()
75-
s.log.Infof("Metrics endpoint listening on: %s (configured: %s)", l.Addr().String(), s.config.Host)
76-
77-
err := s.httpServer.Serve(l)
78-
s.log.Infof("Stats endpoint (%s) finished: %v", l.Addr().String(), err)
79-
}(s.l)
80+
81+
switch s.state {
82+
case stateNew:
83+
s.state = stateStarted
84+
s.log.Info("Starting stats endpoint")
85+
s.wg.Add(1)
86+
s.httpServer = &http.Server{Handler: s.mux} //nolint:gosec // Keep original behavior
87+
go func(l net.Listener) {
88+
defer s.wg.Done()
89+
s.log.Infof("Metrics endpoint listening on: %s (configured: %s)", l.Addr().String(), s.config.Host)
90+
91+
err := s.httpServer.Serve(l)
92+
s.log.Infof("Stats endpoint (%s) finished: %v", l.Addr().String(), err)
93+
}(s.l)
94+
return
95+
case stateStarted:
96+
// only call Start once
97+
s.log.Debug("not starting stats endpoint because start was already called")
98+
return
99+
case stateStopped:
100+
s.log.Debug("not starting stats endpoint because stop was already called")
101+
return
102+
default:
103+
s.log.Errorf("unknown stats server state: %d", s.state)
104+
}
80105
}
81106

82107
// Stop stops the API server and free any resource associated with the process like unix sockets.
83108
func (s *Server) Stop() error {
84109
s.mutex.Lock()
85110
defer s.mutex.Unlock()
86-
if s.httpServer == nil {
111+
112+
switch s.state {
113+
case stateNew:
114+
s.state = stateStopped
115+
// New always creates a listener, need to close it even if the server hasn't started
116+
if err := s.l.Close(); err != nil {
117+
s.log.Infof("error closing stats endpoint (%s): %v", s.l.Addr().String(), err)
118+
}
87119
return nil
120+
case stateStarted:
121+
s.state = stateStopped
122+
// Closing the server will also close the listener
123+
if err := s.httpServer.Close(); err != nil {
124+
return fmt.Errorf("error closing monitoring server: %w", err)
125+
}
126+
s.wg.Wait()
127+
return nil
128+
case stateStopped:
129+
// only need to call Stop once
130+
s.log.Debug("not stopping stats endpoint because stop was already called")
131+
return nil
132+
default:
133+
return fmt.Errorf("unknown stats server state: %d", s.state)
88134
}
89-
if err := s.httpServer.Close(); err != nil {
90-
return fmt.Errorf("error closing monitoring server: %w", err)
91-
}
92-
s.wg.Wait()
93-
return nil
94135
}
95136

96137
// AttachHandler will attach a handler at the specified route. Routes are

libbeat/api/server_test.go

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,20 @@ package api
1919

2020
import (
2121
"context"
22+
"encoding/base64"
2223
"io"
24+
"math/rand/v2"
2325
"net"
2426
"net/http"
2527
"net/http/httptest"
2628
"os"
29+
"path/filepath"
2730
"runtime"
2831
"testing"
2932

3033
"github.com/stretchr/testify/assert"
3134
"github.com/stretchr/testify/require"
35+
"go.uber.org/goleak"
3236

3337
"github.com/elastic/elastic-agent-libs/config"
3438
"github.com/elastic/elastic-agent-libs/logp/logptest"
@@ -41,7 +45,7 @@ func TestConfiguration(t *testing.T) {
4145
return
4246
}
4347
t.Run("when user is set", func(t *testing.T) {
44-
cfg := config.MustNewConfigFrom(map[string]interface{}{
48+
cfg := config.MustNewConfigFrom(map[string]any{
4549
"host": "unix:///tmp/ok",
4650
"user": "admin",
4751
})
@@ -51,7 +55,7 @@ func TestConfiguration(t *testing.T) {
5155
})
5256

5357
t.Run("when security descriptor is set", func(t *testing.T) {
54-
cfg := config.MustNewConfigFrom(map[string]interface{}{
58+
cfg := config.MustNewConfigFrom(map[string]any{
5559
"host": "unix:///tmp/ok",
5660
"security_descriptor": "D:P(A;;GA;;;1234)",
5761
})
@@ -87,7 +91,7 @@ func TestSocket(t *testing.T) {
8791
sockFile := tmpDir + "/test.sock"
8892
t.Log(sockFile)
8993

90-
cfg := config.MustNewConfigFrom(map[string]interface{}{
94+
cfg := config.MustNewConfigFrom(map[string]any{
9195
"host": "unix://" + sockFile,
9296
})
9397

@@ -130,7 +134,7 @@ func TestSocket(t *testing.T) {
130134
require.NoError(t, err)
131135
f.Close()
132136

133-
cfg := config.MustNewConfigFrom(map[string]interface{}{
137+
cfg := config.MustNewConfigFrom(map[string]any{
134138
"host": "unix://" + sockFile,
135139
})
136140

@@ -167,7 +171,7 @@ func TestHTTP(t *testing.T) {
167171
// select a random free port.
168172
url := "http://localhost:0"
169173

170-
cfg := config.MustNewConfigFrom(map[string]interface{}{
174+
cfg := config.MustNewConfigFrom(map[string]any{
171175
"host": url,
172176
})
173177
logger := logptest.NewTestingLogger(t, "")
@@ -198,7 +202,7 @@ func attachEchoHelloHandler(t *testing.T, s *Server) {
198202
}
199203

200204
func TestAttachHandler(t *testing.T) {
201-
cfg := config.MustNewConfigFrom(map[string]interface{}{
205+
cfg := config.MustNewConfigFrom(map[string]any{
202206
"host": "http://localhost:0",
203207
})
204208

@@ -222,8 +226,84 @@ func TestAttachHandler(t *testing.T) {
222226
assert.Equal(t, http.StatusMovedPermanently, resp.Result().StatusCode)
223227
}
224228

229+
func TestOrdering(t *testing.T) {
230+
monitorSocket := genSocketPath()
231+
var monitorHost string
232+
if runtime.GOOS == "windows" {
233+
monitorHost = "npipe:///" + filepath.Base(monitorSocket)
234+
} else {
235+
monitorHost = "unix://" + monitorSocket
236+
}
237+
cfg := config.MustNewConfigFrom(map[string]any{
238+
"host": monitorHost,
239+
})
240+
241+
t.Run("NewStartStop", func(t *testing.T) {
242+
defer goleak.VerifyNone(t)
243+
logger := logptest.NewTestingLogger(t, "")
244+
s, err := New(logger, cfg)
245+
require.NoError(t, err)
246+
s.Start()
247+
err = s.Stop()
248+
require.NoError(t, err)
249+
s.wg.Wait()
250+
})
251+
t.Run("NewStopStart", func(t *testing.T) {
252+
defer goleak.VerifyNone(t)
253+
logger := logptest.NewTestingLogger(t, "")
254+
s, err := New(logger, cfg)
255+
require.NoError(t, err)
256+
err = s.Stop()
257+
require.NoError(t, err)
258+
s.Start()
259+
s.wg.Wait()
260+
})
261+
t.Run("NewStop", func(t *testing.T) {
262+
defer goleak.VerifyNone(t)
263+
logger := logptest.NewTestingLogger(t, "")
264+
s, err := New(logger, cfg)
265+
require.NoError(t, err)
266+
err = s.Stop()
267+
require.NoError(t, err)
268+
s.wg.Wait()
269+
})
270+
t.Run("NewStopStop", func(t *testing.T) {
271+
defer goleak.VerifyNone(t)
272+
logger := logptest.NewTestingLogger(t, "")
273+
s, err := New(logger, cfg)
274+
require.NoError(t, err)
275+
err = s.Stop()
276+
require.NoError(t, err)
277+
err = s.Stop()
278+
require.NoError(t, err)
279+
s.wg.Wait()
280+
})
281+
t.Run("NewStartStartStop", func(t *testing.T) {
282+
defer goleak.VerifyNone(t)
283+
logger := logptest.NewTestingLogger(t, "")
284+
s, err := New(logger, cfg)
285+
require.NoError(t, err)
286+
s.Start()
287+
s.Start()
288+
err = s.Stop()
289+
require.NoError(t, err)
290+
s.wg.Wait()
291+
})
292+
}
293+
225294
func newTestHandler(response string) http.Handler {
226295
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
227296
_, _ = io.WriteString(w, response)
228297
})
229298
}
299+
300+
func genSocketPath() string {
301+
randData := make([]byte, 16)
302+
for i := range len(randData) {
303+
randData[i] = uint8(rand.UintN(255)) //nolint:gosec // 0-255 fits in a uint8
304+
}
305+
socketName := base64.URLEncoding.EncodeToString(randData) + ".sock"
306+
// don't use t.TempDir() because it can be too long
307+
socketDir := os.TempDir()
308+
return filepath.Join(socketDir, socketName)
309+
}

0 commit comments

Comments
 (0)