diff --git a/pkg/gateway/capabilitites.go b/pkg/gateway/capabilitites.go index d9a51def..edf40c8e 100644 --- a/pkg/gateway/capabilitites.go +++ b/pkg/gateway/capabilitites.go @@ -7,6 +7,7 @@ import ( "slices" "strings" "sync" + "time" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -110,6 +111,11 @@ func (caps *Capabilities) getResourceTemplateByURITemplate(resource string) (Res return ResourceTemplateRegistration{}, fmt.Errorf("unable to find resource template") } +// perServerTimeout is the maximum time to wait for a single server to respond +// during capability listing. This prevents unreachable servers (e.g. VPN-only +// endpoints) from blocking the entire startup for their full transport timeout. +const perServerTimeout = 15 * time.Second + func (g *Gateway) listCapabilities(ctx context.Context, serverNames []string, clientConfig *clientConfig) (*Capabilities, error) { var ( lock sync.Mutex @@ -128,7 +134,10 @@ func (g *Gateway) listCapabilities(ctx context.Context, serverNames []string, cl // It's an MCP Server case serverConfig != nil: errs.Go(func() error { - client, err := g.clientPool.AcquireClient(ctx, serverConfig, clientConfig) + serverCtx, cancel := context.WithTimeout(ctx, perServerTimeout) + defer cancel() + + client, err := g.clientPool.AcquireClient(serverCtx, serverConfig, clientConfig) if err != nil { log.Logf(" > Can't start %s: %s", serverConfig.Name, err) return nil @@ -137,12 +146,12 @@ func (g *Gateway) listCapabilities(ctx context.Context, serverNames []string, cl var capabilities Capabilities - tools, err := client.Session().ListTools(ctx, &mcp.ListToolsParams{}) + tools, err := client.Session().ListTools(serverCtx, &mcp.ListToolsParams{}) if err != nil { log.Logf(" > Can't list tools %s: %s", serverConfig.Name, err) } else { // Record the number of tools discovered from this server - telemetry.RecordToolList(ctx, serverConfig.Name, len(tools.Tools)) + telemetry.RecordToolList(serverCtx, serverConfig.Name, len(tools.Tools)) // Determine the prefix to use for this server's tools prefix := g.getToolNamePrefix(serverConfig) @@ -164,10 +173,10 @@ func (g *Gateway) listCapabilities(ctx context.Context, serverNames []string, cl } } - prompts, err := client.Session().ListPrompts(ctx, &mcp.ListPromptsParams{}) + prompts, err := client.Session().ListPrompts(serverCtx, &mcp.ListPromptsParams{}) if err == nil { // Record the number of prompts discovered from this server - telemetry.RecordPromptList(ctx, serverConfig.Name, len(prompts.Prompts)) + telemetry.RecordPromptList(serverCtx, serverConfig.Name, len(prompts.Prompts)) for _, prompt := range prompts.Prompts { capabilities.Prompts = append(capabilities.Prompts, PromptRegistration{ @@ -178,10 +187,10 @@ func (g *Gateway) listCapabilities(ctx context.Context, serverNames []string, cl } } - resources, err := client.Session().ListResources(ctx, &mcp.ListResourcesParams{}) + resources, err := client.Session().ListResources(serverCtx, &mcp.ListResourcesParams{}) if err == nil { // Record the number of resources discovered from this server - telemetry.RecordResourceList(ctx, serverConfig.Name, len(resources.Resources)) + telemetry.RecordResourceList(serverCtx, serverConfig.Name, len(resources.Resources)) for _, resource := range resources.Resources { capabilities.Resources = append(capabilities.Resources, ResourceRegistration{ @@ -192,10 +201,10 @@ func (g *Gateway) listCapabilities(ctx context.Context, serverNames []string, cl } } - resourceTemplates, err := client.Session().ListResourceTemplates(ctx, &mcp.ListResourceTemplatesParams{}) + resourceTemplates, err := client.Session().ListResourceTemplates(serverCtx, &mcp.ListResourceTemplatesParams{}) if err == nil { // Record the number of resource templates discovered from this server - telemetry.RecordResourceTemplateList(ctx, serverConfig.Name, len(resourceTemplates.ResourceTemplates)) + telemetry.RecordResourceTemplateList(serverCtx, serverConfig.Name, len(resourceTemplates.ResourceTemplates)) for _, resourceTemplate := range resourceTemplates.ResourceTemplates { capabilities.ResourceTemplates = append(capabilities.ResourceTemplates, ResourceTemplateRegistration{ diff --git a/pkg/gateway/run.go b/pkg/gateway/run.go index 8a98cdd5..91672500 100644 --- a/pkg/gateway/run.go +++ b/pkg/gateway/run.go @@ -327,10 +327,6 @@ func (g *Gateway) Run(ctx context.Context) error { } } - if err := g.reloadConfiguration(ctx, configuration, nil, nil); err != nil { - return fmt.Errorf("loading configuration: %w", err) - } - // When running in Container mode, disable OAuth notification monitoring and authentication inContainer := os.Getenv("DOCKER_MCP_IN_CONTAINER") == "1" @@ -422,12 +418,32 @@ func (g *Gateway) Run(ctx context.Context) error { }() } - log.Log("> Initialized in", time.Since(start)) if g.DryRun { + // In dry-run mode, load capabilities synchronously so we validate + // server configs and report discovered tools before exiting. + if err := g.reloadConfiguration(ctx, configuration, nil, nil); err != nil { + log.Logf("> Initial capability load failed: %s", err) + } + log.Log("> Initialized in", time.Since(start)) log.Log("Dry run mode enabled, not starting the server.") return nil } + // Load initial server capabilities in the background. This connects to + // each configured MCP server to discover its tools, prompts, and + // resources. Unreachable servers (e.g. VPN-only endpoints when off-VPN) + // can take 30s+ to time out, so we do this concurrently with starting the + // transport server. The go-sdk's Server.AddTool is thread-safe and + // automatically sends tools/list_changed notifications to connected + // clients once new tools are registered. + go func() { + if err := g.reloadConfiguration(ctx, configuration, nil, nil); err != nil { + log.Logf("> Initial capability load failed: %s", err) + log.Log("> Gateway is running but no tools are available. Tools will load on next configuration update.") + } + log.Log("> Initialized in", time.Since(start)) + }() + // Initialize authentication token for SSE and streaming modes // Skip authentication when running in container (DOCKER_MCP_IN_CONTAINER=1) transport := strings.ToLower(g.Transport)