diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 2d267ce4a3..c49b7e1d23 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -16,6 +16,7 @@ import ( "log/slog" "net" "net/http" + "strconv" "strings" "sync" "time" @@ -66,6 +67,13 @@ const ( // defaultSessionTTL is the default session time-to-live duration. // Sessions that are inactive for this duration will be automatically cleaned up. defaultSessionTTL = 30 * time.Minute + + // defaultIdleCheckInterval is how often the idle reaper scans for inactive sessions. + defaultIdleCheckInterval = time.Minute + + // defaultRetryAfterSeconds is the Retry-After value returned with HTTP 503 + // when the global session limit is reached. + defaultRetryAfterSeconds = 30 ) //go:generate mockgen -destination=mocks/mock_watcher.go -package=mocks -source=server.go Watcher @@ -161,6 +169,21 @@ type Config struct { // SessionFactory creates MultiSessions for session management. // Required; must not be nil. SessionFactory vmcpsession.MultiSessionFactory + + // MaxSessions is the global concurrent session limit when SessionManagementV2 is enabled. + // Requests that would exceed this limit receive HTTP 503 with a Retry-After header. + // 0 uses the default (100). Requires SessionManagementV2 = true. + MaxSessions int + + // MaxSessionsPerClient is the per-identity session limit when SessionManagementV2 is enabled. + // Keyed by auth.Identity.Subject; anonymous clients are not limited. + // 0 uses the default (10). Requires SessionManagementV2 = true. + MaxSessionsPerClient int + + // IdleSessionTimeout is the duration after which inactive sessions are proactively + // expired when SessionManagementV2 is enabled. Must be ≤ SessionTTL. + // 0 uses the default (5 minutes). Requires SessionManagementV2 = true. + IdleSessionTimeout time.Duration } // Server is the Virtual MCP Server that aggregates multiple backends. 
@@ -275,6 +298,24 @@ func New( if cfg.SessionTTL == 0 { cfg.SessionTTL = defaultSessionTTL } + if cfg.MaxSessions == 0 { + cfg.MaxSessions = sessionmanager.DefaultMaxSessions + } + if cfg.MaxSessionsPerClient == 0 { + cfg.MaxSessionsPerClient = sessionmanager.DefaultMaxSessionsPerClient + } + if cfg.IdleSessionTimeout == 0 { + cfg.IdleSessionTimeout = sessionmanager.DefaultIdleSessionTimeout + } + // IdleSessionTimeout must not exceed SessionTTL: if it did, the transport + // TTL reaper could evict sessions before the idle reaper fires, leaving + // per-client counters and idle-tracking maps stale. + if cfg.IdleSessionTimeout > cfg.SessionTTL { + slog.Warn("IdleSessionTimeout exceeds SessionTTL; clamping to SessionTTL", + "idle_session_timeout", cfg.IdleSessionTimeout, + "session_ttl", cfg.SessionTTL) + cfg.IdleSessionTimeout = cfg.SessionTTL + } // Create hooks for SDK integration hooks := &server.Hooks{} @@ -392,7 +433,12 @@ func New( if cfg.SessionFactory == nil { return nil, fmt.Errorf("SessionFactory is required but was not provided") } - vmcpSessMgr := sessionmanager.New(sessionManager, cfg.SessionFactory, backendRegistry) + limits := sessionmanager.Limits{ + MaxSessions: cfg.MaxSessions, + MaxSessionsPerClient: cfg.MaxSessionsPerClient, + IdleSessionTimeout: cfg.IdleSessionTimeout, + } + vmcpSessMgr := sessionmanager.New(sessionManager, cfg.SessionFactory, backendRegistry, limits) // Create Server instance srv := &Server{ @@ -548,6 +594,13 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { slog.Info("audit middleware enabled for MCP endpoints") } + // Apply session limit middleware when V2 session management is active. + // NOTE(review): the auth middleware applied below wraps this handler, so when AuthMiddleware is configured the limit check runs after authentication, not before it; apply this wrap after the auth wrap if rejecting over-limit requests without auth overhead is intended. 
+ if s.vmcpSessionMgr != nil && s.config.MaxSessions > 0 { + mcpHandler = s.sessionLimitMiddleware(mcpHandler) + slog.Info("session limit middleware enabled", "max_sessions", s.config.MaxSessions) + } + // Apply authentication middleware if configured (runs first in chain) if s.config.AuthMiddleware != nil { mcpHandler = s.config.AuthMiddleware(mcpHandler) @@ -566,6 +619,37 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { return mux, nil } +// sessionLimitMiddleware is a best-effort fast-fail for new session requests +// (no Mcp-Session-Id header): it returns HTTP 503 + Retry-After before the +// request reaches the SDK when the global session cap appears to be reached. +// Requests that carry any non-empty Mcp-Session-Id header bypass this check; the header value is not validated here. +// +// This check is intentionally optimistic (non-atomic): it avoids the overhead +// of routing and SDK processing for clearly-over-limit requests, but it does +// not guarantee strict enforcement under concurrent load. Strict enforcement +// is provided atomically by sessionmanager.Manager.Generate(), which uses an +// increment-first reservation to prevent races between concurrent initialize +// requests. +func (s *Server) sessionLimitMiddleware(next http.Handler) http.Handler { + // Resolve the concrete manager once so we can call ActiveSessionCount(); if the type assertion fails, mgr is nil and the handler below skips the limit check entirely. + mgr, _ := s.vmcpSessionMgr.(*sessionmanager.Manager) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Mcp-Session-Id") == "" && mgr != nil { + if mgr.ActiveSessionCount() >= s.config.MaxSessions { + w.Header().Set("Retry-After", strconv.Itoa(defaultRetryAfterSeconds)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte( + `{"error":{"code":-32000,"message":"Maximum concurrent sessions exceeded. 
` + + `Please try again later or contact administrator."}}`, + )) + return + } + } + next.ServeHTTP(w, r) + }) +} + // Start starts the Virtual MCP Server and begins serving requests. // //nolint:gocyclo // Complexity from health monitoring and startup orchestration is acceptable @@ -658,6 +742,19 @@ func (s *Server) Start(ctx context.Context) error { } } + // Start idle session reaper if V2 session management is active with an idle timeout. + if mgr, ok := s.vmcpSessionMgr.(*sessionmanager.Manager); ok && s.config.IdleSessionTimeout > 0 { + idleCtx, idleCancel := context.WithCancel(ctx) + mgr.StartIdleReaper(idleCtx, defaultIdleCheckInterval) + slog.Info("idle session reaper started", + "idle_timeout", s.config.IdleSessionTimeout, + "check_interval", defaultIdleCheckInterval) + s.shutdownFuncs = append(s.shutdownFuncs, func(context.Context) error { + idleCancel() + return nil + }) + } + // Start status reporter if configured if s.statusReporter != nil { shutdown, err := s.statusReporter.Start(ctx) diff --git a/pkg/vmcp/server/sessionmanager/session_manager.go b/pkg/vmcp/server/sessionmanager/session_manager.go index 71532abc0d..79406ab4b1 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager.go +++ b/pkg/vmcp/server/sessionmanager/session_manager.go @@ -18,6 +18,9 @@ import ( "errors" "fmt" "log/slog" + "sync" + "sync/atomic" + "time" "github.com/google/uuid" "github.com/mark3labs/mcp-go/mcp" @@ -39,8 +42,47 @@ const ( // MetadataValTrue is the string value stored under MetadataKeyTerminated // when a session has been terminated. MetadataValTrue = "true" + + // DefaultMaxSessions is the default global concurrent session limit. + // 0 disables the global limit. + DefaultMaxSessions = 100 + + // DefaultMaxSessionsPerClient is the default per-identity session limit. + // 0 disables the per-client limit. + DefaultMaxSessionsPerClient = 10 + + // DefaultIdleSessionTimeout is the default duration after which inactive + // sessions are proactively expired. 
Must be ≤ session TTL. + // 0 disables idle expiry. + DefaultIdleSessionTimeout = 5 * time.Minute + + // defaultIdleCheckInterval is how often the idle reaper scans for idle sessions. + defaultIdleCheckInterval = time.Minute ) +// ErrSessionLimitReached is returned when the global session limit is hit. +var ErrSessionLimitReached = errors.New("maximum concurrent sessions exceeded") + +// ErrPerClientSessionLimitReached is returned when the per-client session limit is hit. +var ErrPerClientSessionLimitReached = errors.New("maximum sessions per client exceeded") + +// Limits configures resource-exhaustion protections for the Manager. +type Limits struct { + // MaxSessions is the maximum number of concurrent sessions globally. + // 0 means unlimited. + MaxSessions int + + // MaxSessionsPerClient is the maximum concurrent sessions per client identity, + // keyed by auth.Identity.Subject. Anonymous clients (no Subject) are not limited. + // 0 means unlimited. + MaxSessionsPerClient int + + // IdleSessionTimeout is the maximum duration a session may be inactive + // before it is proactively expired. Must be ≤ the session TTL. + // 0 disables idle expiry. + IdleSessionTimeout time.Duration +} + // Manager bridges the domain session lifecycle (MultiSession / MultiSessionFactory) // to the mark3labs SDK's SessionIdManager interface. // @@ -68,19 +110,40 @@ type Manager struct { storage *transportsession.Manager factory vmcpsession.MultiSessionFactory backendRegistry vmcp.BackendRegistry + limits Limits + + // perClientMu guards perClientCounts and sessionSubject. + perClientMu sync.Mutex + perClientCounts map[string]int // subject → active session count + sessionSubject map[string]string // sessionID → subject (for decrement on Terminate) + + // idleActivityMu guards idleActivity. 
+ idleActivityMu sync.RWMutex + idleActivity map[string]time.Time // sessionID → last active time + + // activeSessionCount tracks sessions that have been generated but not yet + // terminated, excluding terminated placeholders left for TTL cleanup. + // This gives an accurate count for global limit enforcement, unlike + // storage.Count() which includes those terminated placeholders. + activeSessionCount atomic.Int64 } // New creates a Manager backed by the given transport manager, session factory, -// and backend registry. +// backend registry, and resource-exhaustion limits. func New( storage *transportsession.Manager, factory vmcpsession.MultiSessionFactory, backendRegistry vmcp.BackendRegistry, + limits Limits, ) *Manager { return &Manager{ storage: storage, factory: factory, backendRegistry: backendRegistry, + limits: limits, + perClientCounts: make(map[string]int), + sessionSubject: make(map[string]string), + idleActivity: make(map[string]time.Time), } } @@ -93,6 +156,22 @@ func New( // The placeholder is replaced by CreateSession() in Phase 2 once context // is available via the OnRegisterSession hook. func (sm *Manager) Generate() string { + // Atomically claim a slot before allocating storage. Incrementing first + // (rather than Load → check → Add) eliminates the TOCTOU race where + // concurrent initialize requests all observe Count < MaxSessions and all + // proceed past the cap. If the incremented value exceeds the cap, or if + // storage allocation fails, the slot is released immediately. 
+ if sm.limits.MaxSessions > 0 { + if int(sm.activeSessionCount.Add(1)) > sm.limits.MaxSessions { + sm.activeSessionCount.Add(-1) + slog.Warn("Manager: session limit reached, rejecting new session", + "active", sm.activeSessionCount.Load(), + "max", sm.limits.MaxSessions, + "error", ErrSessionLimitReached) + return "" + } + } + sessionID := uuid.New().String() if err := sm.storage.AddWithID(sessionID); err != nil { @@ -101,10 +180,17 @@ func (sm *Manager) Generate() string { sessionID = uuid.New().String() if err := sm.storage.AddWithID(sessionID); err != nil { slog.Error("Manager: failed to store placeholder session on retry", "session_id", sessionID, "error", err) + if sm.limits.MaxSessions > 0 { + sm.activeSessionCount.Add(-1) + } return "" } } + if sm.limits.MaxSessions <= 0 { + // Unlimited: count is not pre-incremented above, so increment here. + sm.activeSessionCount.Add(1) + } slog.Debug("Manager: generated placeholder session", "session_id", sessionID) return sessionID } @@ -151,6 +237,12 @@ func (sm *Manager) CreateSession( // Resolve the caller identity (may be nil for anonymous access). identity, _ := auth.IdentityFromContext(ctx) + // Enforce per-client session limit for identified callers. + perClientIncremented, err := sm.enforcePerClientLimit(sessionID, identity) + if err != nil { + return nil, err + } + // Note: Token hash and salt are computed and stored by the session factory // (MakeSessionWithID below). Token binding enforcement happens at the session // level via validateCaller(), which uses HMAC-SHA256 with a per-session salt. 
@@ -168,6 +260,9 @@ func (sm *Manager) CreateSession( allowAnonymous := sessiontypes.ShouldAllowAnonymous(identity) sess, err := sm.factory.MakeSessionWithID(ctx, sessionID, identity, allowAnonymous, backends) if err != nil { + if perClientIncremented { + sm.decrementPerClientCount(sessionID) + } return nil, fmt.Errorf("Manager.CreateSession: failed to create multi-session: %w", err) } @@ -180,6 +275,9 @@ func (sm *Manager) CreateSession( placeholder2, exists := sm.storage.Get(sessionID) if !exists { _ = sess.Close() + if perClientIncremented { + sm.decrementPerClientCount(sessionID) + } return nil, fmt.Errorf( "Manager.CreateSession: placeholder for session %q disappeared during backend init (terminated concurrently)", sessionID, @@ -187,6 +285,9 @@ func (sm *Manager) CreateSession( } if placeholder2.GetMetadata()[MetadataKeyTerminated] == MetadataValTrue { _ = sess.Close() + if perClientIncremented { + sm.decrementPerClientCount(sessionID) + } return nil, fmt.Errorf( "Manager.CreateSession: session %q was terminated during backend init (marked after first check)", sessionID, @@ -200,9 +301,15 @@ func (sm *Manager) CreateSession( if err := sm.storage.UpsertSession(sess); err != nil { // Best-effort close of the newly created session to release backend connections. _ = sess.Close() + if perClientIncremented { + sm.decrementPerClientCount(sessionID) + } return nil, fmt.Errorf("Manager.CreateSession: failed to replace placeholder: %w", err) } + // Session is fully established — start the idle clock. 
+ sm.resetIdleActivity(sessionID) + slog.Debug("Manager: created multi-session", "session_id", sessionID, "backend_count", len(backends)) @@ -263,6 +370,11 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { sess, exists := sm.storage.Get(sessionID) if !exists { slog.Debug("Manager.Terminate: session not found (already expired?)", "session_id", sessionID) + // The storage entry may have been removed by TTL cleanup racing with + // Terminate(). Clean up any in-memory map entries that may be left behind + // to prevent per-client counts from sticking and stale idle-reap entries. + sm.decrementPerClientCount(sessionID) + sm.removeIdleActivity(sessionID) return false, nil } @@ -276,6 +388,9 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { if deleteErr := sm.storage.Delete(sessionID); deleteErr != nil { return false, fmt.Errorf("Manager.Terminate: failed to delete session from storage: %w", deleteErr) } + sm.activeSessionCount.Add(-1) + sm.decrementPerClientCount(sessionID) + sm.removeIdleActivity(sessionID) } else { // Placeholder session (not yet upgraded to MultiSession). // @@ -294,6 +409,7 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { // We mark (not delete) so Validate() can return isTerminated=true, which // lets the SDK distinguish "actively terminated" from "never existed". // TTL cleanup will remove the placeholder later. 
+ sm.activeSessionCount.Add(-1) sess.SetMetadata(MetadataKeyTerminated, MetadataValTrue) if replaceErr := sm.storage.UpsertSession(sess); replaceErr != nil { slog.Warn("Manager.Terminate: failed to persist terminated flag for placeholder; attempting delete fallback", @@ -382,6 +498,10 @@ func (sm *Manager) GetAdaptedTools(sessionID string) ([]mcpserver.ServerTool, er return mcp.NewToolResultError(callErr.Error()), nil } + // Reset the idle clock once the tool call has returned without error. + // NOTE(review): the clock is not reset when a call starts or when it fails, so a call that outlasts IdleSessionTimeout can still be reaped mid-execution. + sm.resetIdleActivity(capturedSessionID) + return &mcp.CallToolResult{ Result: mcp.Result{ Meta: conversion.ToMCPMeta(result.Meta), @@ -463,3 +583,134 @@ func (sm *Manager) GetAdaptedResources(sessionID string) ([]mcpserver.ServerReso return sdkResources, nil } + +// ActiveSessionCount returns the number of sessions that have been generated +// but not yet terminated. Unlike storage.Count(), this excludes terminated +// placeholders left in storage for TTL cleanup, giving an accurate measure +// for global session limit enforcement. NOTE(review): Terminate's session-not-found path does not decrement this counter; confirm that sessions removed by storage TTL cleanup release their slot some other way. +func (sm *Manager) ActiveSessionCount() int { + return int(sm.activeSessionCount.Load()) +} + +// --------------------------------------------------------------------------- +// Per-client session limit helpers +// --------------------------------------------------------------------------- + +// enforcePerClientLimit checks and increments the per-client session count for the +// given identity. Returns (true, nil) when the count was incremented, (false, nil) +// when per-client limiting is disabled or the identity has no Subject (anonymous), and (false, err) wrapping ErrPerClientSessionLimitReached when the limit +// is already reached. The caller must call decrementPerClientCount on any failure path when +// the returned bool is true. 
+func (sm *Manager) enforcePerClientLimit(sessionID string, identity *auth.Identity) (bool, error) { + subject := identitySubject(identity) + if sm.limits.MaxSessionsPerClient <= 0 || subject == "" { + return false, nil + } + sm.perClientMu.Lock() + defer sm.perClientMu.Unlock() + if sm.perClientCounts[subject] >= sm.limits.MaxSessionsPerClient { + return false, fmt.Errorf("%w: subject %q", ErrPerClientSessionLimitReached, subject) + } + sm.perClientCounts[subject]++ + sm.sessionSubject[sessionID] = subject + return true, nil +} + +// decrementPerClientCount releases the per-client slot held by sessionID: it deletes the sessionID → subject mapping and decrements that subject's count (never below zero). NOTE(review): a subject whose count reaches zero keeps its perClientCounts entry, so that map grows with the number of distinct subjects seen. +// It is a no-op when the session was never counted (e.g. anonymous sessions) or was already released, so repeated calls are safe. +func (sm *Manager) decrementPerClientCount(sessionID string) { + sm.perClientMu.Lock() + defer sm.perClientMu.Unlock() + subject, ok := sm.sessionSubject[sessionID] + if !ok { + return + } + delete(sm.sessionSubject, sessionID) + if sm.perClientCounts[subject] > 0 { + sm.perClientCounts[subject]-- + } +} + +// identitySubject returns identity.Subject, the key used for per-client session limiting. +// Returns "" for nil identities or identities without a Subject, which opts +// them out of per-client limiting. +func identitySubject(identity *auth.Identity) string { + if identity == nil { + return "" + } + return identity.Subject +} + +// --------------------------------------------------------------------------- +// Idle session timeout helpers +// --------------------------------------------------------------------------- + +// resetIdleActivity records the current time as the last-active timestamp for +// sessionID. In this change it is called when CreateSession succeeds and after each tool call that returns without error. +// No-op when IdleSessionTimeout is zero or negative (idle tracking disabled). 
+func (sm *Manager) resetIdleActivity(sessionID string) { + if sm.limits.IdleSessionTimeout <= 0 { + return + } + sm.idleActivityMu.Lock() + sm.idleActivity[sessionID] = time.Now() + sm.idleActivityMu.Unlock() +} + +// removeIdleActivity removes the idle-tracking entry for sessionID; deleting an absent entry is a no-op. +// Called from Terminate() so the reaper does not attempt to re-terminate. +func (sm *Manager) removeIdleActivity(sessionID string) { + sm.idleActivityMu.Lock() + delete(sm.idleActivity, sessionID) + sm.idleActivityMu.Unlock() +} + +// reapIdleSessions terminates any sessions that have been inactive longer than +// the configured IdleSessionTimeout. Candidate IDs are collected under the read lock and terminated after it is released, because Terminate calls removeIdleActivity, which takes the write lock. NOTE(review): timestamps are not re-checked before Terminate, so a session that becomes active between the scan and the Terminate call is still terminated. +func (sm *Manager) reapIdleSessions() { + cutoff := time.Now().Add(-sm.limits.IdleSessionTimeout) + + sm.idleActivityMu.RLock() + var toTerminate []string + for sessionID, lastActive := range sm.idleActivity { + if lastActive.Before(cutoff) { + toTerminate = append(toTerminate, sessionID) + } + } + sm.idleActivityMu.RUnlock() + + for _, sessionID := range toTerminate { + slog.Info("Manager: terminating idle session", + "session_id", sessionID, + "idle_timeout", sm.limits.IdleSessionTimeout) + if _, err := sm.Terminate(sessionID); err != nil { + slog.Warn("Manager: failed to terminate idle session", + "session_id", sessionID, "error", err) + } + } +} + +// StartIdleReaper starts a background goroutine that periodically calls +// reapIdleSessions. It is a no-op when IdleSessionTimeout is zero or negative (disabled). +// The goroutine is stopped when ctx is cancelled; the caller should add a +// cancel to shutdownFuncs to ensure cleanup on server Stop(). 
+func (sm *Manager) StartIdleReaper(ctx context.Context, interval time.Duration) { + if sm.limits.IdleSessionTimeout <= 0 { + return + } + if interval <= 0 { + interval = defaultIdleCheckInterval + } + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sm.reapIdleSessions() + } + } + }() +} diff --git a/pkg/vmcp/server/sessionmanager/session_manager_test.go b/pkg/vmcp/server/sessionmanager/session_manager_test.go index 481a542721..8373326555 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager_test.go +++ b/pkg/vmcp/server/sessionmanager/session_manager_test.go @@ -162,7 +162,108 @@ func newTestSessionManager( ) (*Manager, *transportsession.Manager) { t.Helper() storage := newTestTransportManager(t) - return New(storage, factory, registry), storage + return New(storage, factory, registry, Limits{}), storage +} + +// newTestVMCPSessionManager is an alias for newTestSessionManager using default (zero) Limits. +func newTestVMCPSessionManager( + t *testing.T, + factory vmcpsession.MultiSessionFactory, + registry vmcp.BackendRegistry, +) (*Manager, *transportsession.Manager) { + return newTestSessionManager(t, factory, registry) +} + +// fakeMultiSession is a minimal in-process MultiSession implementation for tests. 
+type fakeMultiSession struct { + transportsession.Session + tools []vmcp.Tool + closed bool + callToolResult *vmcp.ToolCallResult + callToolErr error + lastCallMeta map[string]any +} + +func newFakeMultiSession(sess transportsession.Session, tools []vmcp.Tool) *fakeMultiSession { + return &fakeMultiSession{Session: sess, tools: tools} +} + +func (f *fakeMultiSession) Tools() []vmcp.Tool { + result := make([]vmcp.Tool, len(f.tools)) + copy(result, f.tools) + return result +} +func (*fakeMultiSession) Resources() []vmcp.Resource { return nil } +func (*fakeMultiSession) Prompts() []vmcp.Prompt { return nil } +func (*fakeMultiSession) BackendSessions() map[string]string { return nil } +func (*fakeMultiSession) GetRoutingTable() *vmcp.RoutingTable { return nil } +func (f *fakeMultiSession) CallTool( + _ context.Context, _ *auth.Identity, _ string, _ map[string]any, meta map[string]any, +) (*vmcp.ToolCallResult, error) { + f.lastCallMeta = meta + return f.callToolResult, f.callToolErr +} +func (*fakeMultiSession) ReadResource(_ context.Context, _ *auth.Identity, _ string) (*vmcp.ResourceReadResult, error) { + return nil, errors.New("not implemented") +} +func (*fakeMultiSession) GetPrompt(_ context.Context, _ *auth.Identity, _ string, _ map[string]any) (*vmcp.PromptGetResult, error) { + return nil, errors.New("not implemented") +} +func (f *fakeMultiSession) Close() error { + f.closed = true + return nil +} + +// fakeMultiSessionFactory is a configurable MultiSessionFactory for tests. 
+type fakeMultiSessionFactory struct { + tools []vmcp.Tool + err error + createdSessions map[string]*fakeMultiSession + delay time.Duration +} + +func newFakeFactory(tools []vmcp.Tool) *fakeMultiSessionFactory { + return &fakeMultiSessionFactory{ + tools: tools, + createdSessions: make(map[string]*fakeMultiSession), + } +} + +func (f *fakeMultiSessionFactory) MakeSession( + _ context.Context, _ *auth.Identity, _ []*vmcp.Backend, +) (vmcpsession.MultiSession, error) { + if f.err != nil { + return nil, f.err + } + sess := newFakeMultiSession(transportsession.NewStreamableSession("auto-id"), f.tools) + f.createdSessions["auto-id"] = sess + return sess, nil +} + +func (f *fakeMultiSessionFactory) MakeSessionWithID( + _ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend, +) (vmcpsession.MultiSession, error) { + if f.delay > 0 { + time.Sleep(f.delay) + } + if f.err != nil { + return nil, f.err + } + sess := newFakeMultiSession(transportsession.NewStreamableSession(id), f.tools) + f.createdSessions[id] = sess + return sess, nil +} + +// newTestVMCPSessionManagerWithLimits creates a Manager with explicit resource limits. 
+func newTestVMCPSessionManagerWithLimits( + t *testing.T, + factory vmcpsession.MultiSessionFactory, + registry vmcp.BackendRegistry, + limits Limits, +) (*Manager, *transportsession.Manager) { + t.Helper() + storage := newTestTransportManager(t) + return New(storage, factory, registry, limits), storage } // --------------------------------------------------------------------------- @@ -203,10 +304,8 @@ func TestSessionManager_Generate(t *testing.T) { ) t.Cleanup(func() { _ = failingMgr.Stop() }) - ctrl := gomock.NewController(t) - sess := newMockSession(t, ctrl, "placeholder", nil) - factory := newMockFactory(t, ctrl, sess) - sm := New(failingMgr, factory, newFakeRegistry()) + factory := newFakeFactory(nil) + sm := New(failingMgr, factory, newFakeRegistry(), Limits{}) id := sm.Generate() assert.Empty(t, id, "Generate() should return '' when storage is unavailable") @@ -666,7 +765,7 @@ func TestSessionManager_Terminate(t *testing.T) { failingStorage, ) t.Cleanup(func() { _ = storage.Stop() }) - sm := New(storage, factory, registry) + sm := New(storage, factory, registry, Limits{}) // Generate a placeholder (first Store, succeeds). sessionID := sm.Generate() @@ -705,7 +804,7 @@ func TestSessionManager_Terminate(t *testing.T) { failingStorage, ) t.Cleanup(func() { _ = storage.Stop() }) - sm := New(storage, factory, registry) + sm := New(storage, factory, registry, Limits{}) // Generate a placeholder (first Store, succeeds). 
sessionID := sm.Generate() @@ -1389,3 +1488,323 @@ func newCallToolRequest(name string, args map[string]any) mcp.CallToolRequest { req.Params.Arguments = args return req } + +// --------------------------------------------------------------------------- +// Tests: per-client session limit +// --------------------------------------------------------------------------- + +func TestVMCPSessionManager_PerClientLimit(t *testing.T) { + t.Parallel() + + t.Run("allows sessions up to MaxSessionsPerClient", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 2}) + + identity := &auth.Identity{Subject: "user-1"} + ctx := auth.WithIdentity(context.Background(), identity) + + id1 := sm.Generate() + _, err := sm.CreateSession(ctx, id1) + require.NoError(t, err) + + id2 := sm.Generate() + _, err = sm.CreateSession(ctx, id2) + require.NoError(t, err) + }) + + t.Run("rejects session when per-client limit reached", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 1}) + + identity := &auth.Identity{Subject: "user-1"} + ctx := auth.WithIdentity(context.Background(), identity) + + id1 := sm.Generate() + _, err := sm.CreateSession(ctx, id1) + require.NoError(t, err) + + id2 := sm.Generate() + _, err = sm.CreateSession(ctx, id2) + require.ErrorIs(t, err, ErrPerClientSessionLimitReached) + }) + + t.Run("count is decremented after Terminate, allowing new session", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 1}) + + identity := &auth.Identity{Subject: "user-1"} + ctx := auth.WithIdentity(context.Background(), identity) + + id1 := 
sm.Generate() + _, err := sm.CreateSession(ctx, id1) + require.NoError(t, err) + + _, err = sm.Terminate(id1) + require.NoError(t, err) + + id2 := sm.Generate() + _, err = sm.CreateSession(ctx, id2) + require.NoError(t, err, "should allow new session after previous was terminated") + }) + + t.Run("anonymous sessions (no Subject) are not limited", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 1}) + + // No identity in context → anonymous. + ctx := context.Background() + + id1 := sm.Generate() + _, err := sm.CreateSession(ctx, id1) + require.NoError(t, err) + + id2 := sm.Generate() + _, err = sm.CreateSession(ctx, id2) + require.NoError(t, err, "anonymous sessions should not be subject to per-client limit") + }) + + t.Run("different subjects have independent counts", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 1}) + + ctx1 := auth.WithIdentity(context.Background(), &auth.Identity{Subject: "user-a"}) + ctx2 := auth.WithIdentity(context.Background(), &auth.Identity{Subject: "user-b"}) + + idA := sm.Generate() + _, err := sm.CreateSession(ctx1, idA) + require.NoError(t, err) + + idB := sm.Generate() + _, err = sm.CreateSession(ctx2, idB) + require.NoError(t, err, "user-b should have its own independent count") + }) +} + +// --------------------------------------------------------------------------- +// Tests: idle session reaper +// --------------------------------------------------------------------------- + +func TestVMCPSessionManager_IdleReaper(t *testing.T) { + t.Parallel() + + t.Run("terminates session that exceeds idle timeout", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + idleTimeout := 5 * 
time.Minute + sm, storage := newTestVMCPSessionManagerWithLimits(t, factory, registry, + Limits{IdleSessionTimeout: idleTimeout}) + + ctx := context.Background() + id := sm.Generate() + _, err := sm.CreateSession(ctx, id) + require.NoError(t, err) + + // Back-date the idle timestamp so the session appears past the timeout + // without any real sleep, making the test immune to CI scheduling jitter. + sm.idleActivityMu.Lock() + sm.idleActivity[id] = time.Now().Add(-(idleTimeout + time.Second)) + sm.idleActivityMu.Unlock() + + sm.reapIdleSessions() + + _, exists := storage.Get(id) + assert.False(t, exists, "idle session should have been reaped") + }) + + t.Run("does not terminate session active within timeout", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory([]vmcp.Tool{{Name: "noop"}}) + registry := newFakeRegistry() + idleTimeout := 200 * time.Millisecond + sm, storage := newTestVMCPSessionManagerWithLimits(t, factory, registry, + Limits{IdleSessionTimeout: idleTimeout}) + + ctx := context.Background() + id := sm.Generate() + _, err := sm.CreateSession(ctx, id) + require.NoError(t, err) + + // Simulate activity by resetting the idle clock. + sm.resetIdleActivity(id) + + // Reap immediately — session should survive. + sm.reapIdleSessions() + + _, exists := storage.Get(id) + assert.True(t, exists, "recently active session should not be reaped") + }) + + t.Run("reaper is no-op when IdleSessionTimeout is zero", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, storage := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{}) + + ctx := context.Background() + id := sm.Generate() + _, err := sm.CreateSession(ctx, id) + require.NoError(t, err) + + // Idle map should be empty when timeout is disabled. 
+ sm.idleActivityMu.RLock() + idleCount := len(sm.idleActivity) + sm.idleActivityMu.RUnlock() + assert.Equal(t, 0, idleCount, "idle map should be empty when timeout is disabled") + + sm.reapIdleSessions() // should not panic or touch storage + + _, exists := storage.Get(id) + assert.True(t, exists, "session should still exist when idle reaper is disabled") + }) +} + +// --------------------------------------------------------------------------- +// Tests: ActiveSessionCount / global limit accuracy +// --------------------------------------------------------------------------- + +func TestVMCPSessionManager_ActiveSessionCount(t *testing.T) { + t.Parallel() + + t.Run("increments on Generate and decrements on Terminate for MultiSession", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManager(t, factory, registry) + + assert.Equal(t, 0, sm.ActiveSessionCount()) + + id := sm.Generate() + assert.Equal(t, 1, sm.ActiveSessionCount()) + + _, err := sm.CreateSession(context.Background(), id) + require.NoError(t, err) + assert.Equal(t, 1, sm.ActiveSessionCount(), "CreateSession should not change the count") + + _, err = sm.Terminate(id) + require.NoError(t, err) + assert.Equal(t, 0, sm.ActiveSessionCount()) + }) + + t.Run("decrements on Terminate for placeholder (terminated but not deleted)", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManager(t, factory, registry) + + id := sm.Generate() + assert.Equal(t, 1, sm.ActiveSessionCount()) + + // Terminate the placeholder before CreateSession — it is marked terminated, + // not deleted, but the active count must still drop. 
+ _, err := sm.Terminate(id) + require.NoError(t, err) + assert.Equal(t, 0, sm.ActiveSessionCount(), + "terminated placeholder must not count towards active sessions") + }) + + t.Run("Generate returns empty string and does not increment count when global limit reached", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessions: 1}) + + // First generate succeeds; active count becomes 1 (== MaxSessions). + id := sm.Generate() + require.NotEmpty(t, id) + assert.Equal(t, 1, sm.ActiveSessionCount()) + + // Second generate must be rejected because the limit is reached. + id2 := sm.Generate() + assert.Empty(t, id2, "Generate must return empty string when global limit is reached") + assert.Equal(t, 1, sm.ActiveSessionCount(), "rejected Generate must not increment active count") + }) + + t.Run("rejected CreateSession (per-client limit) does not leak into active count", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 1}) + + identity := &auth.Identity{Subject: "user-x"} + ctx := auth.WithIdentity(context.Background(), identity) + + // First session: succeeds. + id1 := sm.Generate() + _, err := sm.CreateSession(ctx, id1) + require.NoError(t, err) + assert.Equal(t, 1, sm.ActiveSessionCount()) + + // Second generate: count = 2. + id2 := sm.Generate() + assert.Equal(t, 2, sm.ActiveSessionCount()) + + // CreateSession fails (per-client limit). The server will call Terminate(id2). + _, err = sm.CreateSession(ctx, id2) + require.ErrorIs(t, err, ErrPerClientSessionLimitReached) + _, _ = sm.Terminate(id2) // server-side cleanup + + // Active count must return to 1 (only the first session remains). 
+ assert.Equal(t, 1, sm.ActiveSessionCount(), + "failed registration must not permanently consume the global session budget") + }) + + t.Run("Terminate cleans up in-memory maps when storage entry already removed by TTL", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, storage := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 5}) + + identity := &auth.Identity{Subject: "user-ttl"} + ctx := auth.WithIdentity(context.Background(), identity) + + id := sm.Generate() + require.NotEmpty(t, id) + _, err := sm.CreateSession(ctx, id) + require.NoError(t, err) + + // Simulate TTL eviction by deleting directly from the transport storage, + // bypassing sm.Terminate() (so sessionSubject/idleActivity are NOT cleaned up yet). + require.NoError(t, storage.Delete(id)) + + // Now call Terminate() — storage.Get returns !exists. The fix must still + // clean up the in-memory maps so per-client counts and idle entries don't leak. + _, err = sm.Terminate(id) + require.NoError(t, err) + + // Per-client count for this identity must be back to zero. + sm.perClientMu.Lock() + count := sm.perClientCounts[identity.Subject] + sm.perClientMu.Unlock() + assert.Equal(t, 0, count, "per-client count must be cleaned up even when storage entry was already gone") + + // Idle activity must be removed. + sm.idleActivityMu.RLock() + _, hasIdle := sm.idleActivity[id] + sm.idleActivityMu.RUnlock() + assert.False(t, hasIdle, "idle activity entry must be cleaned up even when storage entry was already gone") + }) +}