Skip to content

Commit e365025

Browse files
committed
Enhance NATS integration and server configuration
- Added comprehensive documentation to `mcp-nats.go` for better understanding of NATS integration. - Improved error handling in context functions to provide clearer error messages. - Introduced helper functions for determining and validating NATS URLs. - Refactored server initialization in `main.go` to utilize a configuration struct, enhancing clarity and maintainability. - Implemented graceful shutdown handling for the server with context management and signal handling.
1 parent bddcb0f commit e365025

File tree

2 files changed

+185
-62
lines changed

2 files changed

+185
-62
lines changed

cmd/mcp-nats/main.go

Lines changed: 86 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,46 @@ import (
55
"flag"
66
"fmt"
77
"os"
8+
"os/signal"
9+
"syscall"
10+
"time"
811

912
"github.com/mark3labs/mcp-go/server"
1013
mcpnats "github.com/sinadarbouy/mcp-nats"
1114
"github.com/sinadarbouy/mcp-nats/internal/logger"
1215
"github.com/sinadarbouy/mcp-nats/tools"
1316
)
1417

18+
const (
19+
// Version of the application
20+
Version = "0.1.0"
21+
// AppName is the name of the application
22+
AppName = "mcp-nats"
23+
)
24+
25+
// Config holds all configuration for the server
26+
type Config struct {
27+
Transport string
28+
SSEAddr string
29+
LogLevel string
30+
JSONLogs bool
31+
}
32+
33+
// validateConfig ensures all config values are valid
34+
func validateConfig(cfg *Config) error {
35+
if cfg.Transport != "stdio" && cfg.Transport != "sse" {
36+
return fmt.Errorf("invalid transport type: %s (must be 'stdio' or 'sse')", cfg.Transport)
37+
}
38+
if cfg.Transport == "sse" && cfg.SSEAddr == "" {
39+
return fmt.Errorf("sse-address cannot be empty when using sse transport")
40+
}
41+
return nil
42+
}
43+
1544
func newServer() (*server.MCPServer, error) {
1645
s := server.NewMCPServer(
17-
"mcp-nats",
18-
"0.1.0",
46+
AppName,
47+
Version,
1948
server.WithResourceCapabilities(true, true),
2049
server.WithLogging(),
2150
server.WithRecovery(),
@@ -24,7 +53,7 @@ func newServer() (*server.MCPServer, error) {
2453
// Initialize NATS server tools
2554
natsTools, err := tools.NewNATSServerTools()
2655
if err != nil {
27-
return nil, fmt.Errorf("failed to initialize NATS tools: %v", err)
56+
return nil, fmt.Errorf("failed to initialize NATS tools: %w", err)
2857
}
2958

3059
// Register all NATS server tools
@@ -33,55 +62,88 @@ func newServer() (*server.MCPServer, error) {
3362
return s, nil
3463
}
3564

36-
func run(transport, addr string) error {
65+
func run(ctx context.Context, cfg *Config) error {
3766
s, err := newServer()
3867
if err != nil {
39-
return err
68+
return fmt.Errorf("failed to create server: %w", err)
4069
}
4170

42-
switch transport {
71+
switch cfg.Transport {
4372
case "stdio":
4473
srv := server.NewStdioServer(s)
4574
srv.SetContextFunc(mcpnats.ComposedStdioContextFunc())
4675
logger.Info("Starting NATS MCP server using stdio transport")
47-
return srv.Listen(context.Background(), os.Stdin, os.Stdout)
76+
return srv.Listen(ctx, os.Stdin, os.Stdout)
77+
4878
case "sse":
4979
srv := server.NewSSEServer(s, server.WithSSEContextFunc(mcpnats.ComposedSSEContextFunc()))
5080
logger.Info("Starting NATS MCP server using SSE transport",
51-
"address", addr,
81+
"address", cfg.SSEAddr,
5282
)
53-
if err := srv.Start(addr); err != nil {
54-
return fmt.Errorf("Server error: %v", err)
83+
84+
errChan := make(chan error, 1)
85+
go func() {
86+
if err := srv.Start(cfg.SSEAddr); err != nil {
87+
errChan <- fmt.Errorf("server error: %w", err)
88+
}
89+
}()
90+
91+
// Wait for either context cancellation or server error
92+
select {
93+
case err := <-errChan:
94+
return err
95+
case <-ctx.Done():
96+
// Give the server some time to shutdown gracefully
97+
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
98+
defer cancel()
99+
return srv.Shutdown(shutdownCtx)
55100
}
56-
default:
57-
return fmt.Errorf(
58-
"Invalid transport type: %s. Must be 'stdio' or 'sse'",
59-
transport,
60-
)
61101
}
102+
62103
return nil
63104
}
64105

65106
func main() {
107+
cfg := &Config{}
108+
66109
// Parse command line flags
67-
transport := flag.String("transport", "stdio", "Transport type (stdio or sse)")
68-
sseAddr := flag.String("sse-address", "0.0.0.0:8000", "Address for SSE server to listen on")
69-
logLevel := flag.String("log-level", "info", "Log level (debug, info, warn, error)")
70-
jsonLogs := flag.Bool("json-logs", false, "Output logs in JSON format")
110+
flag.StringVar(&cfg.Transport, "transport", "stdio", "Transport type (stdio or sse)")
111+
flag.StringVar(&cfg.SSEAddr, "sse-address", "0.0.0.0:8000", "Address for SSE server to listen on")
112+
flag.StringVar(&cfg.LogLevel, "log-level", "info", "Log level (debug, info, warn, error)")
113+
flag.BoolVar(&cfg.JSONLogs, "json-logs", false, "Output logs in JSON format")
71114
flag.Parse()
72115

116+
// Validate configuration
117+
if err := validateConfig(cfg); err != nil {
118+
fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err)
119+
os.Exit(1)
120+
}
121+
73122
// Initialize logger
74123
logger.Initialize(logger.Config{
75-
Level: logger.GetLevel(*logLevel),
76-
JSONFormat: *jsonLogs,
124+
Level: logger.GetLevel(cfg.LogLevel),
125+
JSONFormat: cfg.JSONLogs,
77126
})
78127

79128
logger.Info("Starting MCP NATS server",
80-
"transport", *transport,
81-
"version", "0.1.0",
129+
"transport", cfg.Transport,
130+
"version", Version,
82131
)
83132

84-
if err := run(*transport, *sseAddr); err != nil {
133+
// Setup context with cancellation for graceful shutdown
134+
ctx, cancel := context.WithCancel(context.Background())
135+
defer cancel()
136+
137+
// Handle shutdown signals
138+
sigChan := make(chan os.Signal, 1)
139+
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
140+
go func() {
141+
sig := <-sigChan
142+
logger.Info("Received shutdown signal", "signal", sig)
143+
cancel()
144+
}()
145+
146+
if err := run(ctx, cfg); err != nil {
85147
logger.Error("Server failed",
86148
"error", err,
87149
)

0 commit comments

Comments
 (0)