Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions libbeat/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ import (
// Server takes care of correctly starting the HTTP component of the API
// and will answer all the routes defined in the received ServeMux.
type Server struct {
log *logp.Logger
mux *http.ServeMux
l net.Listener
config Config
wg sync.WaitGroup
mutex sync.Mutex
httpServer *http.Server
log *logp.Logger
mux *http.ServeMux
l net.Listener
config Config
wg sync.WaitGroup
mutex sync.Mutex
httpServer *http.Server
startCalled bool
stopCalled bool
}

// New creates a new API Server with no routes attached.
Expand Down Expand Up @@ -67,6 +69,17 @@ func New(log *logp.Logger, config *config.C) (*Server, error) {
func (s *Server) Start() {
s.mutex.Lock()
defer s.mutex.Unlock()

// only need to call Start once
if s.startCalled {
return
}
s.startCalled = true

if s.stopCalled {
s.log.Info("Not starting stating stats endpoint since stop was already called")
return
}
s.log.Info("Starting stats endpoint")
s.wg.Add(1)
s.httpServer = &http.Server{Handler: s.mux} //nolint:gosec // Keep original behavior
Expand All @@ -83,12 +96,23 @@ func (s *Server) Start() {
func (s *Server) Stop() error {
s.mutex.Lock()
defer s.mutex.Unlock()
// only need to call Stop once
if s.stopCalled {
return nil
}
s.stopCalled = true

if s.httpServer == nil {
// New always creates a listener, need to close it even if the server hasn't started
if err := s.l.Close(); err != nil {
s.log.Infof("Error closing stats endpoint (%s): %v", s.l.Addr().String(), err)
}
return nil
}
if err := s.httpServer.Close(); err != nil {
return fmt.Errorf("error closing monitoring server: %w", err)
}

s.wg.Wait()
return nil
}
Expand Down
92 changes: 86 additions & 6 deletions libbeat/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@ package api

import (
"context"
"encoding/base64"
"io"
"math/rand/v2"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/elastic/elastic-agent-libs/config"
"github.com/elastic/elastic-agent-libs/logp/logptest"
Expand All @@ -41,7 +45,7 @@ func TestConfiguration(t *testing.T) {
return
}
t.Run("when user is set", func(t *testing.T) {
cfg := config.MustNewConfigFrom(map[string]interface{}{
cfg := config.MustNewConfigFrom(map[string]any{
"host": "unix:///tmp/ok",
"user": "admin",
})
Expand All @@ -51,7 +55,7 @@ func TestConfiguration(t *testing.T) {
})

t.Run("when security descriptor is set", func(t *testing.T) {
cfg := config.MustNewConfigFrom(map[string]interface{}{
cfg := config.MustNewConfigFrom(map[string]any{
"host": "unix:///tmp/ok",
"security_descriptor": "D:P(A;;GA;;;1234)",
})
Expand Down Expand Up @@ -87,7 +91,7 @@ func TestSocket(t *testing.T) {
sockFile := tmpDir + "/test.sock"
t.Log(sockFile)

cfg := config.MustNewConfigFrom(map[string]interface{}{
cfg := config.MustNewConfigFrom(map[string]any{
"host": "unix://" + sockFile,
})

Expand Down Expand Up @@ -130,7 +134,7 @@ func TestSocket(t *testing.T) {
require.NoError(t, err)
f.Close()

cfg := config.MustNewConfigFrom(map[string]interface{}{
cfg := config.MustNewConfigFrom(map[string]any{
"host": "unix://" + sockFile,
})

Expand Down Expand Up @@ -167,7 +171,7 @@ func TestHTTP(t *testing.T) {
// select a random free port.
url := "http://localhost:0"

cfg := config.MustNewConfigFrom(map[string]interface{}{
cfg := config.MustNewConfigFrom(map[string]any{
"host": url,
})
logger := logptest.NewTestingLogger(t, "")
Expand Down Expand Up @@ -198,7 +202,7 @@ func attachEchoHelloHandler(t *testing.T, s *Server) {
}

func TestAttachHandler(t *testing.T) {
cfg := config.MustNewConfigFrom(map[string]interface{}{
cfg := config.MustNewConfigFrom(map[string]any{
"host": "http://localhost:0",
})

Expand All @@ -222,8 +226,84 @@ func TestAttachHandler(t *testing.T) {
assert.Equal(t, http.StatusMovedPermanently, resp.Result().StatusCode)
}

func TestOrdering(t *testing.T) {
monitorSocket := genSocketPath()
var monitorHost string
if runtime.GOOS == "windows" {
monitorHost = "npipe:///" + filepath.Base(monitorSocket)
} else {
monitorHost = "unix://" + monitorSocket
}
cfg := config.MustNewConfigFrom(map[string]any{
"host": monitorHost,
})

t.Run("NewStartStop", func(t *testing.T) {
defer goleak.VerifyNone(t)
logger := logptest.NewTestingLogger(t, "")
s, err := New(logger, cfg)
require.NoError(t, err)
s.Start()
err = s.Stop()
require.NoError(t, err)
s.wg.Wait()
})
t.Run("NewStopStart", func(t *testing.T) {
defer goleak.VerifyNone(t)
logger := logptest.NewTestingLogger(t, "")
s, err := New(logger, cfg)
require.NoError(t, err)
err = s.Stop()
require.NoError(t, err)
s.Start()
s.wg.Wait()
})
t.Run("NewStop", func(t *testing.T) {
defer goleak.VerifyNone(t)
logger := logptest.NewTestingLogger(t, "")
s, err := New(logger, cfg)
require.NoError(t, err)
err = s.Stop()
require.NoError(t, err)
s.wg.Wait()
})
t.Run("NewStopStop", func(t *testing.T) {
defer goleak.VerifyNone(t)
logger := logptest.NewTestingLogger(t, "")
s, err := New(logger, cfg)
require.NoError(t, err)
err = s.Stop()
require.NoError(t, err)
err = s.Stop()
require.NoError(t, err)
s.wg.Wait()
})
t.Run("NewStartStartStop", func(t *testing.T) {
defer goleak.VerifyNone(t)
logger := logptest.NewTestingLogger(t, "")
s, err := New(logger, cfg)
require.NoError(t, err)
s.Start()
s.Start()
err = s.Stop()
require.NoError(t, err)
s.wg.Wait()
})
}

func newTestHandler(response string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, response)
})
}

func genSocketPath() string {
randData := make([]byte, 16)
for i := range len(randData) {
randData[i] = uint8(rand.UintN(255)) //nolint:gosec // 0-255 fits in a uint8
}
socketName := base64.URLEncoding.EncodeToString(randData) + ".sock"
// don't use t.TempDir() because it can be too long
socketDir := os.TempDir()
return filepath.Join(socketDir, socketName)
}
Loading