diff --git a/cmd/xsql/command_unit_test.go b/cmd/xsql/command_unit_test.go index f904cba..d8b8e3c 100644 --- a/cmd/xsql/command_unit_test.go +++ b/cmd/xsql/command_unit_test.go @@ -8,6 +8,8 @@ import ( "path/filepath" "testing" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/zx06/xsql/internal/app" "github.com/zx06/xsql/internal/config" "github.com/zx06/xsql/internal/errors" @@ -381,6 +383,51 @@ func TestRunMCPServer_ConfigMissing(t *testing.T) { } } +func TestRunMCPServer_StdioTreatsContextCanceledAsCleanExit(t *testing.T) { + prevRun := runMCPStdioServer + runMCPStdioServer = func(ctx context.Context, _ *mcp.Server) error { + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + return cancelCtx.Err() + } + defer func() { + runMCPStdioServer = prevRun + }() + + configPath := filepath.Join(t.TempDir(), "xsql.yaml") + if err := os.WriteFile(configPath, []byte("profiles: {}\nssh_proxies: {}\n"), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + GlobalConfig.ConfigStr = configPath + err := runMCPServer(&mcpServerOptions{}) + if err != nil { + t.Fatalf("expected nil error for canceled stdio server, got %v", err) + } +} + +func TestRunMCPServer_StdioPropagatesNonCanceledError(t *testing.T) { + prevRun := runMCPStdioServer + wantErr := context.DeadlineExceeded + runMCPStdioServer = func(ctx context.Context, _ *mcp.Server) error { + return wantErr + } + defer func() { + runMCPStdioServer = prevRun + }() + + configPath := filepath.Join(t.TempDir(), "xsql.yaml") + if err := os.WriteFile(configPath, []byte("profiles: {}\nssh_proxies: {}\n"), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + GlobalConfig.ConfigStr = configPath + err := runMCPServer(&mcpServerOptions{}) + if err != wantErr { + t.Fatalf("expected %v, got %v", wantErr, err) + } +} + func TestResolveMCPServerOptions_Defaults(t *testing.T) { cfg := config.File{ Profiles: map[string]config.Profile{}, diff --git a/cmd/xsql/mcp.go b/cmd/xsql/mcp.go index d47bad1..932fa97 100644 --- a/cmd/xsql/mcp.go +++ b/cmd/xsql/mcp.go @@ -2,6 +2,8 @@ package main import ( "context" + stderrors "errors" + "log" "net/http" "os" "os/signal" @@ -17,6 +19,10 @@ import ( "github.com/zx06/xsql/internal/secret" ) +var runMCPStdioServer = func(ctx context.Context, server *mcp.Server) error { + return server.Run(ctx, &mcp.StdioTransport{}) +} + // NewMCPCommand creates the MCP command group func NewMCPCommand() *cobra.Command { mcpCmd := &cobra.Command{ @@ -75,8 +81,22 @@ func runMCPServer(opts *mcpServerOptions) error { switch resolved.transport { case mcp_pkg.TransportStdio: - ctx := context.Background() - return server.Run(ctx, &mcp.StdioTransport{}) + // Install signal handler for graceful shutdown in stdio mode + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigChan + signal.Stop(sigChan) + cancel() + }() + + if err := runMCPStdioServer(ctx, server); err != nil && !stderrors.Is(err, context.Canceled) { + return err + } + return nil case mcp_pkg.TransportStreamableHTTP: handler, err := mcp_pkg.NewStreamableHTTPHandler(server, resolved.httpAuthToken) if err != nil { @@ -86,8 +106,11 @@ func runMCPServer(opts *mcpServerOptions) error { return errors.Wrap(errors.CodeInternal, "failed to create streamable http handler", nil, err) } httpServer := &http.Server{ - Addr: resolved.httpAddr, - Handler: handler, + Addr: resolved.httpAddr, + Handler: handler, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, } sigChan := make(chan os.Signal, 1) @@ -95,12 +118,18 @@ func runMCPServer(opts *mcpServerOptions) error { go func() { <-sigChan + signal.Stop(sigChan) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - _ = httpServer.Shutdown(ctx) + if shutdownErr := httpServer.Shutdown(ctx); shutdownErr != nil { + log.Printf("[mcp] http server shutdown error: %v", shutdownErr) + } }() - return httpServer.ListenAndServe() + if listenErr := httpServer.ListenAndServe(); listenErr != nil && listenErr != http.ErrServerClosed { + return listenErr + } + return nil default: return errors.New(errors.CodeCfgInvalid, "unsupported mcp transport", map[string]any{"transport": resolved.transport}) } diff --git a/internal/app/conn.go b/internal/app/conn.go index 48175ad..9e1496c 100644 --- a/internal/app/conn.go +++ b/internal/app/conn.go @@ -3,6 +3,7 @@ package app import ( "context" "database/sql" + "sync" "github.com/zx06/xsql/internal/config" "github.com/zx06/xsql/internal/db" @@ -15,6 +16,7 @@ type Connection struct { DB *sql.DB SSHClient *ssh.Client Profile config.Profile + closeMu sync.Mutex closeHooks []func() } @@ -30,7 +32,11 @@ func (c *Connection) Close() error { errs = append(errs, err) } } - for _, fn := range c.closeHooks { + c.closeMu.Lock() + hooks := c.closeHooks + c.closeHooks = nil + c.closeMu.Unlock() + for _, fn := range hooks { if fn != nil { fn() } @@ -81,6 +87,7 @@ func ResolveConnection(ctx context.Context, opts ConnectionOptions) (*Connection } closeHooks := make([]func(), 0, 1) + var hooksMu sync.Mutex connOpts := db.ConnOptions{ DSN: opts.Profile.DSN, Host: opts.Profile.Host, @@ -90,7 +97,9 @@ func ResolveConnection(ctx context.Context, opts ConnectionOptions) (*Connection Database: opts.Profile.Database, RegisterCloseHook: func(fn func()) { if fn != nil { + hooksMu.Lock() closeHooks = append(closeHooks, fn) + hooksMu.Unlock() } }, } diff --git a/internal/db/mysql/driver.go b/internal/db/mysql/driver.go index 43ce78a..52c8d28 100644 --- a/internal/db/mysql/driver.go +++ b/internal/db/mysql/driver.go @@ -17,9 +17,10 @@ import ( ) var ( - dialerCounter uint64 - dialers sync.Map - registeredDials sync.Map + dialerCounter uint64 + dialers sync.Map + registerDialContextFn = mysql.RegisterDialContext + deregisterDialContextFn = mysql.DeregisterDialContext ) func init() { @@ -30,24 +31,30 @@ func registerDialContext(dialer func(context.Context, string, string) (net.Conn, dialerNum := atomic.AddUint64(&dialerCounter, 1) dialName := fmt.Sprintf("xsql_ssh_tunnel_%d", dialerNum) - if _, loaded := registeredDials.LoadOrStore(dialName, true); !loaded { - mysql.RegisterDialContext(dialName, func(ctx context.Context, addr string) (net.Conn, error) { - d, ok := dialers.Load(dialName) - if !ok { - return nil, fmt.Errorf("dialer not found: %s", dialName) - } - fn, ok := d.(func(context.Context, string, string) (net.Conn, error)) - if !ok || fn == nil { - return nil, fmt.Errorf("invalid dialer for: %s", dialName) - } - return fn(ctx, "tcp", addr) - }) - } + registerDialContextFn(dialName, func(ctx context.Context, addr string) (net.Conn, error) { + d, ok := dialers.Load(dialName) + if !ok { + return nil, fmt.Errorf("dialer not found: %s", dialName) + } + fn, ok := d.(func(context.Context, string, string) (net.Conn, error)) + if !ok || fn == nil { + return nil, fmt.Errorf("invalid dialer for: %s", dialName) + } + return fn(ctx, "tcp", addr) + }) dialers.Store(dialName, dialer) return dialName } +func cleanupDialContext(dialName string) { + if dialName == "" { + return + } + dialers.Delete(dialName) + deregisterDialContextFn(dialName) +} + type Driver struct{} func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *errors.XError) { @@ -80,7 +87,7 @@ func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *error cfg.Net = dialName if opts.RegisterCloseHook != nil { opts.RegisterCloseHook(func() { - dialers.Delete(dialName) + cleanupDialContext(dialName) }) } } @@ -94,9 +101,7 @@ func (d *Driver) Open(ctx context.Context, opts db.ConnOptions) (*sql.DB, *error if closeErr := conn.Close(); closeErr != nil { log.Printf("failed to close mysql connection: %v", closeErr) } - if dialName != "" { - dialers.Delete(dialName) - } + cleanupDialContext(dialName) return nil, errors.Wrap(errors.CodeDBConnectFailed, "failed to ping mysql", nil, err) } return conn, nil diff --git a/internal/db/mysql/driver_test.go b/internal/db/mysql/driver_test.go index 56a672f..1ab0641 100644 --- a/internal/db/mysql/driver_test.go +++ b/internal/db/mysql/driver_test.go @@ -4,6 +4,7 @@ import ( "context" "net" "sync" + "sync/atomic" "testing" "time" @@ -197,7 +198,15 @@ func TestDriver_Open_ContextCancelled(t *testing.T) { func TestDriver_Open_WithDialer_CleanupOnFailure(t *testing.T) { resetSyncMap(&dialers) - resetSyncMap(®isteredDials) + var deregisterCalls int32 + prevDeregister := deregisterDialContextFn + deregisterDialContextFn = func(net string) { + atomic.AddInt32(&deregisterCalls, 1) + prevDeregister(net) + } + defer func() { + deregisterDialContextFn = prevDeregister + }() drv, _ := db.Get("mysql") ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -227,11 +236,17 @@ func TestDriver_Open_WithDialer_CleanupOnFailure(t *testing.T) { if countSyncMap(&dialers) != 0 { t.Fatal("expected dialers map to be cleaned on open failure") } + if got := atomic.LoadInt32(&deregisterCalls); got != 1 { + t.Fatalf("expected one deregister call on open failure, got %d", got) + } hooks[0]() if countSyncMap(&dialers) != 0 { t.Fatal("expected hook cleanup to be idempotent") } + if got := atomic.LoadInt32(&deregisterCalls); got != 2 { + t.Fatalf("expected close hook cleanup to remain safe, got %d deregister calls", got) + } } func countSyncMap(m *sync.Map) int { diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go index 8c11c5c..2833ec0 100644 --- a/internal/mcp/tools.go +++ b/internal/mcp/tools.go @@ -34,6 +34,12 @@ type ToolHandler struct { // NewToolHandler creates a new tool handler func NewToolHandler(cfg *config.File) *ToolHandler { + if cfg == nil { + cfg = &config.File{ + Profiles: map[string]config.Profile{}, + SSHProxies: map[string]config.SSHProxy{}, + } + } return &ToolHandler{ config: cfg, } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index bdd4974..24cf64e 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -151,16 +151,22 @@ func (p *Proxy) handleConnection(localConn net.Conn, remoteAddr string) { // Bidirectional copy var wg sync.WaitGroup - wg.Add(2) + errChan := make(chan error, 2) + wg.Add(1) go func() { defer wg.Done() - _, _ = io.Copy(localConn, remoteConn) + if _, err := io.Copy(localConn, remoteConn); err != nil { + errChan <- fmt.Errorf("copy remote->local failed: %w", err) + } }() + wg.Add(1) go func() { defer wg.Done() - _, _ = io.Copy(remoteConn, localConn) + if _, err := io.Copy(remoteConn, localConn); err != nil { + errChan <- fmt.Errorf("copy local->remote failed: %w", err) + } }() // Wait for both copies to finish or context cancellation @@ -172,7 +178,24 @@ func (p *Proxy) handleConnection(localConn net.Conn, remoteAddr string) { select { case <-done: + // Check if there were any copy errors + select { + case err := <-errChan: + log.Printf("[proxy] connection copy error: %v", err) + default: + } case <-p.ctx.Done(): + // Context cancelled: close connections to unblock io.Copy goroutines + _ = localConn.Close() + _ = remoteConn.Close() + // Wait for goroutines to finish + <-done + // Check for any final errors + select { + case err := <-errChan: + log.Printf("[proxy] connection copy error on shutdown: %v", err) + default: + } } } diff --git a/tests/e2e/mcp_test.go b/tests/e2e/mcp_test.go index dc156fa..89da464 100644 --- a/tests/e2e/mcp_test.go +++ b/tests/e2e/mcp_test.go @@ -8,7 +8,11 @@ import ( "os" "os/exec" "path/filepath" + "runtime" + "strings" + "syscall" "testing" + "time" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -234,6 +238,51 @@ func TestMCPServer_EmptyConfig(t *testing.T) { } } +func TestMCPServer_SIGINTCleanExit(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("requires POSIX signals") + } + + config := createTempConfig(t, `profiles: {}`) + + cmd := exec.Command(testBinary, "mcp", "server", "--config", config) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start MCP server: %v", err) + } + + time.Sleep(200 * time.Millisecond) + + if err := cmd.Process.Signal(syscall.SIGINT); err != nil { + t.Fatalf("failed to send SIGINT: %v", err) + } + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case err := <-done: + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + t.Fatalf("expected clean exit, got exit code %d, stderr: %s", exitErr.ExitCode(), stderr.String()) + } + t.Fatalf("expected clean exit, got %v", err) + } + case <-time.After(5 * time.Second): + _ = cmd.Process.Kill() + t.Fatal("mcp server did not exit after SIGINT") + } + + if strings.Contains(stderr.String(), "context canceled") { + t.Fatalf("stderr should not contain context canceled, got: %s", stderr.String()) + } +} + // listMCPTools lists all available MCP tools func listMCPTools(t *testing.T, configPath string) []mcp.Tool { t.Helper()