diff --git a/libbeat/api/server.go b/libbeat/api/server.go index 28e498b5fb6d..4212714f72c0 100644 --- a/libbeat/api/server.go +++ b/libbeat/api/server.go @@ -31,6 +31,14 @@ import ( "github.com/elastic/elastic-agent-libs/logp" ) +type serverState int + +const ( + stateNew = iota + stateStarted + stateStopped +) + // Server takes care of correctly starting the HTTP component of the API // and will answer all the routes defined in the received ServeMux. type Server struct { @@ -41,6 +49,7 @@ type Server struct { wg sync.WaitGroup mutex sync.Mutex httpServer *http.Server + state serverState } // New creates a new API Server with no routes attached. @@ -61,6 +70,7 @@ func New(log *logp.Logger, config *config.C) (*Server, error) { l: l, config: cfg, log: log.Named("api"), + state: stateNew, }, nil } @@ -68,30 +78,61 @@ func New(log *logp.Logger, config *config.C) (*Server, error) { func (s *Server) Start() { s.mutex.Lock() defer s.mutex.Unlock() - s.log.Info("Starting stats endpoint") - s.wg.Add(1) - s.httpServer = &http.Server{Handler: s.mux} //nolint:gosec // Keep original behavior - go func(l net.Listener) { - defer s.wg.Done() - s.log.Infof("Metrics endpoint listening on: %s (configured: %s)", l.Addr().String(), s.config.Host) - - err := s.httpServer.Serve(l) - s.log.Infof("Stats endpoint (%s) finished: %v", l.Addr().String(), err) - }(s.l) + + switch s.state { + case stateNew: + s.state = stateStarted + s.log.Info("Starting stats endpoint") + s.wg.Add(1) + s.httpServer = &http.Server{Handler: s.mux} //nolint:gosec // Keep original behavior + go func(l net.Listener) { + defer s.wg.Done() + s.log.Infof("Metrics endpoint listening on: %s (configured: %s)", l.Addr().String(), s.config.Host) + + err := s.httpServer.Serve(l) + s.log.Infof("Stats endpoint (%s) finished: %v", l.Addr().String(), err) + }(s.l) + return + case stateStarted: + // only call Start once + s.log.Debug("not starting stats endpoint because start was already called") + return + case stateStopped: + s.log.Debug("not starting stats endpoint because stop was already called") + return + default: + s.log.Errorf("unknown stats server state: %d", s.state) + } } // Stop stops the API server and free any resource associated with the process like unix sockets. func (s *Server) Stop() error { s.mutex.Lock() defer s.mutex.Unlock() - if s.httpServer == nil { + + switch s.state { + case stateNew: + s.state = stateStopped + // New always creates a listener, need to close it even if the server hasn't started + if err := s.l.Close(); err != nil { + s.log.Infof("error closing stats endpoint (%s): %v", s.l.Addr().String(), err) + } return nil + case stateStarted: + s.state = stateStopped + // Closing the server will also close the listener + if err := s.httpServer.Close(); err != nil { + return fmt.Errorf("error closing monitoring server: %w", err) + } + s.wg.Wait() + return nil + case stateStopped: + // only need to call Stop once + s.log.Debug("not stopping stats endpoint because stop was already called") + return nil + default: + return fmt.Errorf("unknown stats server state: %d", s.state) } - if err := s.httpServer.Close(); err != nil { - return fmt.Errorf("error closing monitoring server: %w", err) - } - s.wg.Wait() - return nil } // AttachHandler will attach a handler at the specified route. Routes are diff --git a/libbeat/api/server_test.go b/libbeat/api/server_test.go index f72903b3cff9..1306939c3467 100644 --- a/libbeat/api/server_test.go +++ b/libbeat/api/server_test.go @@ -19,16 +19,20 @@ package api import ( "context" + "encoding/base64" "io" + "math/rand/v2" "net" "net/http" "net/http/httptest" "os" + "path/filepath" "runtime" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "github.com/elastic/elastic-agent-libs/config" "github.com/elastic/elastic-agent-libs/logp/logptest" @@ -41,7 +45,7 @@ func TestConfiguration(t *testing.T) { return } t.Run("when user is set", func(t *testing.T) { - cfg := config.MustNewConfigFrom(map[string]interface{}{ + cfg := config.MustNewConfigFrom(map[string]any{ "host": "unix:///tmp/ok", "user": "admin", }) @@ -51,7 +55,7 @@ func TestConfiguration(t *testing.T) { }) t.Run("when security descriptor is set", func(t *testing.T) { - cfg := config.MustNewConfigFrom(map[string]interface{}{ + cfg := config.MustNewConfigFrom(map[string]any{ "host": "unix:///tmp/ok", "security_descriptor": "D:P(A;;GA;;;1234)", }) @@ -87,7 +91,7 @@ func TestSocket(t *testing.T) { sockFile := tmpDir + "/test.sock" t.Log(sockFile) - cfg := config.MustNewConfigFrom(map[string]interface{}{ + cfg := config.MustNewConfigFrom(map[string]any{ "host": "unix://" + sockFile, }) @@ -130,7 +134,7 @@ func TestSocket(t *testing.T) { require.NoError(t, err) f.Close() - cfg := config.MustNewConfigFrom(map[string]interface{}{ + cfg := config.MustNewConfigFrom(map[string]any{ "host": "unix://" + sockFile, }) @@ -167,7 +171,7 @@ func TestHTTP(t *testing.T) { // select a random free port. url := "http://localhost:0" - cfg := config.MustNewConfigFrom(map[string]interface{}{ + cfg := config.MustNewConfigFrom(map[string]any{ "host": url, }) logger := logptest.NewTestingLogger(t, "") @@ -198,7 +202,7 @@ func attachEchoHelloHandler(t *testing.T, s *Server) { } func TestAttachHandler(t *testing.T) { - cfg := config.MustNewConfigFrom(map[string]interface{}{ + cfg := config.MustNewConfigFrom(map[string]any{ "host": "http://localhost:0", }) @@ -223,8 +227,84 @@ func TestAttachHandler(t *testing.T) { assert.Equal(t, "test!", resp.Body.String()) } +func TestOrdering(t *testing.T) { + monitorSocket := genSocketPath() + var monitorHost string + if runtime.GOOS == "windows" { + monitorHost = "npipe:///" + filepath.Base(monitorSocket) + } else { + monitorHost = "unix://" + monitorSocket + } + cfg := config.MustNewConfigFrom(map[string]any{ + "host": monitorHost, + }) + + t.Run("NewStartStop", func(t *testing.T) { + defer goleak.VerifyNone(t) + logger := logptest.NewTestingLogger(t, "") + s, err := New(logger, cfg) + require.NoError(t, err) + s.Start() + err = s.Stop() + require.NoError(t, err) + s.wg.Wait() + }) + t.Run("NewStopStart", func(t *testing.T) { + defer goleak.VerifyNone(t) + logger := logptest.NewTestingLogger(t, "") + s, err := New(logger, cfg) + require.NoError(t, err) + err = s.Stop() + require.NoError(t, err) + s.Start() + s.wg.Wait() + }) + t.Run("NewStop", func(t *testing.T) { + defer goleak.VerifyNone(t) + logger := logptest.NewTestingLogger(t, "") + s, err := New(logger, cfg) + require.NoError(t, err) + err = s.Stop() + require.NoError(t, err) + s.wg.Wait() + }) + t.Run("NewStopStop", func(t *testing.T) { + defer goleak.VerifyNone(t) + logger := logptest.NewTestingLogger(t, "") + s, err := New(logger, cfg) + require.NoError(t, err) + err = s.Stop() + require.NoError(t, err) + err = s.Stop() + require.NoError(t, err) + s.wg.Wait() + }) + t.Run("NewStartStartStop", func(t *testing.T) { + defer goleak.VerifyNone(t) + logger := logptest.NewTestingLogger(t, "") + s, err := New(logger, cfg) + require.NoError(t, err) + s.Start() + s.Start() + err = s.Stop() + require.NoError(t, err) + s.wg.Wait() + }) +} + func newTestHandler(response string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = io.WriteString(w, response) }) } + +func genSocketPath() string { + randData := make([]byte, 16) + for i := range len(randData) { + randData[i] = uint8(rand.UintN(255)) //nolint:gosec // 0-255 fits in a uint8 + } + socketName := base64.URLEncoding.EncodeToString(randData) + ".sock" + // don't use t.TempDir() because it can be too long + socketDir := os.TempDir() + return filepath.Join(socketDir, socketName) +}